-
-
Save mvandermeulen/c9869ddfe40c82cfcc4357c1cd6cfd8a to your computer and use it in GitHub Desktop.
Face Recognition
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from fastapi import FastAPI, HTTPException, UploadFile, Form | |
from fastapi.middleware.cors import CORSMiddleware | |
from prisma import Prisma | |
from deepface import DeepFace | |
import numpy as np | |
from io import BytesIO | |
from PIL import Image, UnidentifiedImageError | |
import json | |
import os | |
import base64 | |
import uvicorn | |
# Silence TensorFlow's native (C++) log output.
# NOTE(review): this assignment runs after `from deepface import DeepFace`
# above; if the intent is to hide TensorFlow's import-time logs it probably
# has to happen before that import — confirm.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Model name handed to DeepFace.represent when computing face embeddings.
MODEL_NAME = "VGG-Face"

app = FastAPI()

# One Prisma client shared by every request handler; the startup/shutdown
# hooks connect and disconnect it.
db = Prisma()

# Browser origins allowed to call this API cross-origin.
# NOTE(review): 5173/5174 look like Vite dev-server ports, which serve plain
# http:// by default — confirm the front end really runs over https.
origins = [
    "https://localhost:5174",
    "https://localhost:5173",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE(review): `on_event` hooks are deprecated in recent FastAPI releases in
# favour of a `lifespan` handler; consider migrating.
@app.on_event("startup")
async def startup():
    """Open the Prisma database connection when the application starts."""
    await db.connect()


@app.on_event("shutdown")
async def shutdown():
    """Close the Prisma database connection when the application stops."""
    await db.disconnect()
@app.get("/")
def backend_init():
    """Root endpoint: report that the backend is up."""
    status_text = "Backend running successfully :))"
    return {"backend": status_text}
# Utility function to calculate cosine similarity
def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity between two embedding vectors.

    Args:
        embedding1: Sequence or array of numbers.
        embedding2: Sequence or array of numbers with the same length as
            ``embedding1``.

    Returns:
        A Python float, nominally in [-1, 1]. When either vector has zero
        norm the similarity is undefined; 0.0 is returned instead of NaN (and
        a divide-by-zero RuntimeWarning), so such a vector never counts as a
        strong match.

    Raises:
        ValueError: (from NumPy) if the vectors have different lengths.
    """
    vec1 = np.asarray(embedding1, dtype=float)
    vec2 = np.asarray(embedding2, dtype=float)
    denominator = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denominator == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / denominator)
def preprocess_image(photo_bytes: bytes) -> np.ndarray:
    """Decode uploaded image bytes into a fixed-size RGB NumPy array.

    The image is converted to RGB when needed and resized to 224x224. The
    resize does not preserve the aspect ratio.

    Args:
        photo_bytes: Raw bytes of the uploaded file.

    Returns:
        A uint8 array of shape (224, 224, 3) in RGB channel order.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image. This
            includes truncated or corrupt files, which Pillow only detects
            (as OSError) once the pixel data is loaded.
    """
    # NOTE(review): DeepFace documents numpy inputs as BGR (OpenCV order), but
    # this array is RGB. Registration and recognition both go through here,
    # so stored and query embeddings are at least consistent; switching to
    # BGR would invalidate embeddings already in the database — evaluate first.
    try:
        with Image.open(BytesIO(photo_bytes)) as photo:
            # convert()/resize() force the pixel data to load; that is where
            # damaged files raise OSError (UnidentifiedImageError subclasses it).
            rgb = photo if photo.mode == "RGB" else photo.convert("RGB")
            resized = rgb.resize((224, 224))
    except (UnidentifiedImageError, OSError) as err:
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image") from err
    return np.array(resized)
def extract_face_embedding(photo_bytes: bytes):
    """Compute a face embedding for an uploaded image with DeepFace.

    Args:
        photo_bytes: Raw bytes of the uploaded image.

    Returns:
        The ``"embedding"`` entry of the first result returned by
        ``DeepFace.represent`` for ``MODEL_NAME``.

    Raises:
        HTTPException: Any HTTPException raised by ``preprocess_image`` (for
            example its 400 for an undecodable upload) is passed through
            unchanged. Any other failure is reported as a 500.
    """
    try:
        photo_array = preprocess_image(photo_bytes)
        # Debugging: log the preprocessed array's shape.
        print(f"Processing image with shape: {photo_array.shape}")
        # enforce_detection=False presumably lets DeepFace return an embedding
        # even when no face is detected (per DeepFace docs) — confirm desired.
        results = DeepFace.represent(
            img_path=photo_array, model_name=MODEL_NAME, enforce_detection=False
        )
        return results[0]["embedding"]
    except HTTPException:
        # Keep the client-facing status (e.g. 400) instead of turning it into a 500.
        raise
    except Exception as e:
        print(f"Error extracting embedding: {str(e)}")  # Debugging log
        raise HTTPException(status_code=500, detail=f"Error extracting embedding: {str(e)}") from e
async def save_to_db(table: str, name: str, age: int, gender: str, photo_base64: str, embedding: list):
    """Insert a registered person into the selected Prisma model.

    Args:
        table: ``"employee"`` or ``"visitor"``; selects ``db.employee`` or
            ``db.visitor``.
        name: Person's name from the registration form.
        age: Person's age from the registration form.
        gender: Person's gender from the registration form.
        photo_base64: Base64-encoded original photo.
        embedding: Face embedding; stored as a JSON string.

    Raises:
        ValueError: If ``table`` is not one of the supported names. Raising
            here avoids silently skipping the insert.
    """
    models = {"employee": db.employee, "visitor": db.visitor}
    model = models.get(table)
    if model is None:
        raise ValueError(f"Unknown table: {table!r}")
    await model.create(
        data={
            "name": name,
            "age": age,
            "gender": gender,
            "photoBase64": photo_base64,
            # JSON string; the recognition endpoints decode it with json.loads.
            "embedding": json.dumps(embedding),
        }
    )
@app.post("/register-employee/")
async def register_employee(
    name: str = Form(...), age: int = Form(...), gender: str = Form(...), photo: UploadFile = None
):
    """Register an employee from form fields and an uploaded face photo.

    Stores the base64-encoded photo and its face embedding via ``save_to_db``.

    Raises:
        HTTPException: 400 if no photo is uploaded. HTTPExceptions raised by
            the helpers are passed through unchanged. Any other failure is
            reported as a 500.
    """
    # `photo` defaults to None; without this guard photo.read() raises
    # AttributeError, which would be reported as a server error.
    if photo is None:
        raise HTTPException(status_code=400, detail="A photo file is required")
    try:
        photo_bytes = await photo.read()
        photo_base64 = base64.b64encode(photo_bytes).decode("utf-8")
        embedding = extract_face_embedding(photo_bytes)
        await save_to_db("employee", name, age, gender, photo_base64, embedding)
        return {"message": "Employee registered successfully"}
    except HTTPException:
        # Already carries the intended status code; don't re-wrap it as a 500.
        raise
    except Exception as e:
        print(f"Error during employee registration: {str(e)}")  # Debugging log
        raise HTTPException(status_code=500, detail=f"Error during employee registration: {str(e)}") from e
@app.post("/register-visitor/")
async def register_visitor(
    name: str = Form(...), age: int = Form(...), gender: str = Form(...), photo: UploadFile = None
):
    """Register a visitor from form fields and an uploaded face photo.

    Stores the base64-encoded photo and its face embedding via ``save_to_db``.

    Raises:
        HTTPException: 400 if no photo is uploaded. HTTPExceptions raised by
            the helpers are passed through unchanged. Any other failure is
            reported as a 500.
    """
    # `photo` defaults to None; without this guard photo.read() raises
    # AttributeError, which would be reported as a server error.
    if photo is None:
        raise HTTPException(status_code=400, detail="A photo file is required")
    try:
        photo_bytes = await photo.read()
        photo_base64 = base64.b64encode(photo_bytes).decode("utf-8")
        embedding = extract_face_embedding(photo_bytes)
        await save_to_db("visitor", name, age, gender, photo_base64, embedding)
        return {"message": "Visitor registered successfully"}
    except HTTPException:
        # Already carries the intended status code; don't re-wrap it as a 500.
        raise
    except Exception as e:
        print(f"Error during visitor registration: {str(e)}")  # Debugging log
        raise HTTPException(status_code=500, detail=f"Error during visitor registration: {str(e)}") from e
@app.post("/recognize-employee/")
async def recognize_employee(photo: UploadFile):
    """Find the registered employee whose stored embedding best matches ``photo``.

    Returns:
        ``{"name", "similarity"}`` for the most similar employee when the
        cosine similarity exceeds 0.85. Otherwise a ``{"message": ...}`` dict,
        either because no employees are registered or because nothing matched.

    Raises:
        HTTPException: HTTPExceptions raised by the helpers are passed
            through unchanged. Any other failure is reported as a 500.
    """
    try:
        photo_bytes = await photo.read()
        uploaded_embedding = extract_face_embedding(photo_bytes)

        employees = await db.employee.find_many()
        if not employees:
            return {"message": "No employees registered."}

        # Linear scan over every stored embedding, keeping the most similar one.
        best_match = None
        highest_similarity = -1  # below any real cosine similarity
        for employee in employees:
            stored_embedding = json.loads(employee.embedding)  # saved as a JSON string
            similarity = calculate_cosine_similarity(uploaded_embedding, stored_embedding)
            if similarity > highest_similarity:
                highest_similarity = similarity
                best_match = employee

        # Example threshold; stricter than the visitor endpoint's 0.6 — confirm intended.
        if best_match is not None and highest_similarity > 0.85:
            return {"name": best_match.name, "similarity": float(highest_similarity)}
        return {"message": "No match found."}
    except HTTPException:
        # e.g. a 400 for an invalid image; keep its status code.
        raise
    except Exception as e:
        print(f"Error in employee recognition: {str(e)}")  # Debugging log
        raise HTTPException(status_code=500, detail=f"Error in face recognition: {str(e)}") from e
@app.post("/recognize-visitor/")
async def recognize_visitor(photo: UploadFile):
    """Find the registered visitor whose stored embedding best matches ``photo``.

    Returns:
        ``{"name", "similarity"}`` for the most similar visitor when the
        cosine similarity exceeds 0.6. Otherwise a ``{"message": ...}`` dict,
        either because no visitors are registered or because nothing matched.

    Raises:
        HTTPException: HTTPExceptions raised by the helpers are passed
            through unchanged. Any other failure is reported as a 500.
    """
    try:
        photo_bytes = await photo.read()
        uploaded_embedding = extract_face_embedding(photo_bytes)

        visitors = await db.visitor.find_many()
        if not visitors:
            return {"message": "No visitors registered."}

        # Linear scan over every stored embedding, keeping the most similar one.
        best_match = None
        highest_similarity = -1  # below any real cosine similarity
        for visitor in visitors:
            stored_embedding = json.loads(visitor.embedding)  # saved as a JSON string
            similarity = calculate_cosine_similarity(uploaded_embedding, stored_embedding)
            if similarity > highest_similarity:
                highest_similarity = similarity
                best_match = visitor

        # Example threshold; looser than the employee endpoint's 0.85 — confirm intended.
        if best_match is not None and highest_similarity > 0.6:
            return {"name": best_match.name, "similarity": float(highest_similarity)}
        return {"message": "No match found."}
    except HTTPException:
        # e.g. a 400 for an invalid image; keep its status code.
        raise
    except Exception as e:
        print(f"Error in visitor recognition: {str(e)}")  # Debugging log
        raise HTTPException(status_code=500, detail=f"Error in face recognition: {str(e)}") from e
if __name__ == "__main__":
    # Run the API with uvicorn on the loopback interface, port 8000, when executed directly.
    uvicorn.run(app, port=8000, host="127.0.0.1")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from prisma import Prisma | |
db = Prisma()


async def connect_to_db():
    """Connect the module-level Prisma client and print the outcome.

    Any connection error is printed and then re-raised to the caller.
    """
    try:
        await db.connect()
        print("Database connection successful!")
    except Exception as exc:
        print(f"Failed to connect to the database: {exc!s}")
        raise
async def disconnect_from_db():
    """Disconnect the Prisma client on a best-effort basis.

    Errors are printed and swallowed rather than propagated.
    """
    try:
        await db.disconnect()
        print("Disconnected from the database.")
    except Exception as exc:
        # Deliberately not re-raised; presumably so a failed disconnect does
        # not abort the caller's shutdown sequence.
        print(f"Error while disconnecting: {exc!s}")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
services:
  postgres:
    image: postgres:15-alpine  # Official PostgreSQL 15 image (Alpine variant)
    container_name: postgres_db  # Fixed container name
    ports:
      - "5432:5432"  # Publish PostgreSQL's default port on the host
    environment:
      # Development defaults; override via the shell or an .env file rather
      # than committing real credentials.
      POSTGRES_USER: ${POSTGRES_USER:-postgres}
      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres}
      POSTGRES_DB: ${POSTGRES_DB:-my_database}  # Database created on first initialization
    volumes:
      # The postgres image stores its cluster in /var/lib/postgresql/data.
      # Mounting storage at any other container path (e.g. /backend/storage)
      # leaves the database unpersisted. A named volume also avoids a
      # machine-specific host path.
      - pgdata:/var/lib/postgresql/data
    restart: always  # Restart the container if it stops

volumes:
  pgdata:
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment