from fastapi import FastAPI, Response
from pydantic import BaseModel
import easyocr
import requests
import re
import cv2
import numpy as np
import math
from rapidfuzz import fuzz
from pdf2image import convert_from_path
import io
from PIL import Image, ImageOps


# FastAPI application object; the /extract_name route further down is
# registered on it via the @app.post decorator.
app = FastAPI()

# Initialize EasyOCR reader once, at import time, with the English language
# list. The same module-level instance is reused by every OCR call in this
# file (auto_correct_rotation and extract_name_from_image).
reader = easyocr.Reader(['en'])

# Request body accepted by the /extract_name endpoint.
class InputData(BaseModel):
    # Name supplied by the caller; it is fuzzy-compared against the name
    # extracted from the document.
    input_name: str
    # This URL is fetched with download_file() when the endpoint runs.
    file: str  # URL to Aadhaar/Voter card image or PDF

# --- Helpers ---
def download_file(url: str, timeout: float = 15.0) -> "bytes | None":
    """Fetch ``url`` over HTTP and return the response body.

    Args:
        url: Location of the document to download.
        timeout: Seconds to wait for the server (connect and read) before
            giving up. Without a timeout a stalled server would hang the
            request indefinitely.

    Returns:
        The raw response bytes when the server answers with status 200,
        otherwise ``None`` (non-200 status, timeout, DNS/connection failure
        or malformed URL).
    """
    # NOTE(review): the URL comes straight from the request body, so this can
    # be pointed at internal addresses (SSRF). Consider restricting schemes
    # and hosts before deploying on a network with private services.
    try:
        resp = requests.get(url, timeout=timeout)
    except requests.RequestException:
        # Network-level failures are reported the same way as a bad status so
        # the caller's single "could not download" branch handles them.
        return None
    if resp.status_code != 200:
        return None
    return resp.content

def convert_pdf_to_image(pdf_bytes):
    """Render the first page of a PDF as JPEG bytes.

    The PDF is converted directly from memory, so no shared ``temp.pdf`` is
    written to the working directory (which concurrent requests would
    overwrite and which was never cleaned up).

    Args:
        pdf_bytes: Raw contents of the PDF file.

    Returns:
        JPEG-encoded bytes of page 1 rendered at 300 dpi.

    Raises:
        ValueError: If the PDF yields no pages.
    """
    # Local import: only this function needs the in-memory converter.
    from pdf2image import convert_from_bytes

    # Only the first page is used, so do not render the rest of the document.
    pages = convert_from_bytes(pdf_bytes, dpi=300, first_page=1, last_page=1)
    if not pages:
        raise ValueError("PDF contains no renderable pages")

    img_bytes = io.BytesIO()
    pages[0].save(img_bytes, format="JPEG")
    return img_bytes.getvalue()

def crop_aadhaar_region(img, debug_path=None):
    """Cut out the horizontal band of the card used for the Aadhaar name lookup.

    The band spans 24%-42% of the image height and 12%-88% of its width
    (an earlier, narrower 28%-38% / 15%-85% band was widened to this).

    Args:
        img: Image array whose first two dimensions are (height, width).
        debug_path: Optional file path. When given, the cropped region is
            also written there for inspection. It defaults to ``None`` so
            that normal requests do not write a shared debug file to disk.

    Returns:
        The cropped region (a slice of ``img``).
    """
    h, w = img.shape[:2]
    top, bottom = int(h * 0.24), int(h * 0.42)
    left, right = int(w * 0.12), int(w * 0.88)
    roi = img[top:bottom, left:right]
    if debug_path:
        cv2.imwrite(debug_path, roi)  # Debug: inspect cropped region
    return roi

def detect_card_type(text: str) -> str:
    """Classify OCR text as coming from an Aadhaar card or a voter ID.

    Args:
        text: OCR output of the whole document.

    Returns:
        ``"aadhaar"`` if the text contains "government of india" or the
        standalone token "VID"; ``"voter"`` if it contains "election
        commission of india" or "elector"; otherwise ``"auto"``.
        Matching is case-insensitive and Aadhaar markers are checked first.
    """
    text_lower = text.lower()
    # "vid" must be a whole word: a plain substring test also fires on names
    # such as "Vidya" or "David" and would mislabel voter cards as Aadhaar.
    if "government of india" in text_lower or re.search(r"\bvid\b", text_lower):
        return "aadhaar"
    if "election commission of india" in text_lower or "elector" in text_lower:
        return "voter"
    return "auto"

def rotate_image(img, angle):
    """Rotate ``img`` by a quarter-turn multiple.

    ``angle`` values of 90, 180 and 270 map to OpenCV's clockwise-90, 180 and
    counter-clockwise-90 rotations respectively. Any other value returns
    ``img`` unchanged (the same object).
    """
    if angle not in (90, 180, 270):
        return img
    rotate_codes = {
        90: cv2.ROTATE_90_CLOCKWISE,
        180: cv2.ROTATE_180,
        270: cv2.ROTATE_90_COUNTERCLOCKWISE,
    }
    return cv2.rotate(img, rotate_codes[angle])

def estimate_text_angle(results):
    """Return the median direction, in degrees, of the detected text boxes.

    For every entry in ``results`` the first element is taken as the box and
    the angle of the segment from its first point to its second point is
    measured with ``atan2`` (presumably the top edge of the box; this depends
    on the OCR library's point ordering).

    Returns ``0`` when ``results`` is empty, otherwise the NumPy median of the
    per-box angles.
    """
    def _first_edge_angle(box):
        (x_start, y_start), (x_end, y_end) = box[0], box[1]
        return math.degrees(math.atan2(y_end - y_start, x_end - x_start))

    angles = [_first_edge_angle(entry[0]) for entry in results]
    if not angles:
        return 0
    return np.median(angles)

def auto_correct_rotation(img):
    """Rotate ``img`` by a multiple of 90 degrees based on one OCR pass.

    A copy capped at 700 px on its longest side is used for detection only;
    the rotation itself is applied to the full-resolution ``img``. If OCR
    finds nothing, ``img`` is returned untouched.
    """
    height, width = img.shape[:2]
    longest_side = max(height, width)

    # Downscale only for orientation detection.
    probe = img
    if longest_side > 700:
        factor = 700 / longest_side
        target_size = (int(width * factor), int(height * factor))
        probe = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)

    # Single OCR pass with bounding boxes, one entry per text fragment.
    detections = reader.readtext(probe, detail=1, paragraph=False)
    if not detections:
        return img

    tilt = estimate_text_angle(detections)

    # Snap the measured text direction to the quarter turn that undoes it.
    if abs(tilt) <= 45:
        quarter_turn = 0
    elif 45 < tilt <= 135:
        quarter_turn = 270
    elif -135 <= tilt < -45:
        quarter_turn = 90
    else:
        quarter_turn = 180

    return rotate_image(img, quarter_turn)

def fix_exif_orientation(image_bytes):
    """Decode image bytes with Pillow, apply the EXIF orientation, return BGR.

    Args:
        image_bytes: Encoded image data (any format Pillow can open).

    Returns:
        A BGR NumPy array suitable for OpenCV, or ``None`` if the bytes could
        not be decoded or converted. Callers fall back to another decoder on
        ``None``, so failures are deliberately not raised.
    """
    try:
        img = Image.open(io.BytesIO(image_bytes))
        img = ImageOps.exif_transpose(img)
        # Palette, grayscale, RGBA and CMYK images do not produce the
        # 3-channel array that COLOR_RGB2BGR requires; normalise them first
        # so they are not silently rejected.
        if img.mode != "RGB":
            img = img.convert("RGB")
        return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    except Exception:
        # Best-effort step: any decode/convert failure means "not handled here".
        return None

# NOTE(review): this is a second, identical definition of rotate_image; it
# replaces the one defined earlier in this module at import time. One of the
# two can be removed.
def rotate_image(img, angle):
    """Return ``img`` turned by 90, 180 or 270 degrees; other angles are a no-op.

    90 uses OpenCV's clockwise quarter turn, 270 the counter-clockwise one.
    """
    if angle == 90:
        rotation = cv2.ROTATE_90_CLOCKWISE
    elif angle == 180:
        rotation = cv2.ROTATE_180
    elif angle == 270:
        rotation = cv2.ROTATE_90_COUNTERCLOCKWISE
    else:
        return img
    return cv2.rotate(img, rotation)

def looks_like_real_name(text: str) -> bool:
    """Heuristic filter that rejects OCR junk posing as a name.

    Returns ``True`` only when every whitespace-separated word has at least
    three characters and the whole string contains at least three ASCII
    vowels (either case).
    """
    # Any very short token is treated as OCR noise.
    if any(len(token) < 3 for token in text.split()):
        return False
    vowel_total = sum(text.count(vowel) for vowel in "aeiouAEIOU")
    return vowel_total >= 3
def deskew_image(img):
    """Straighten a slightly tilted image.

    The skew is estimated from the minimum-area rectangle around the
    foreground pixels of an Otsu-thresholded, inverted grayscale copy, and the
    image is rotated about its centre by the opposite amount.

    Returns ``img`` unchanged when fewer than 100 foreground pixels are found
    or when the estimated correction exceeds 30 degrees; otherwise returns a
    rotated image of the same size.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # Invert so that (presumably) dark text becomes the bright foreground.
    gray = cv2.bitwise_not(gray)

    # Otsu picks the threshold automatically; the 0 passed here is ignored.
    thresh = cv2.threshold(
        gray, 0, 255,
        cv2.THRESH_BINARY | cv2.THRESH_OTSU
    )[1]

    # One (row, col) pair per foreground pixel.
    coords = np.column_stack(np.where(thresh > 0))
    if len(coords) < 100:
        return img  # nothing to deskew

    # Last element of the min-area rectangle is its rotation angle.
    angle = cv2.minAreaRect(coords)[-1]

    # NOTE(review): this normalisation looks written for a minAreaRect that
    # reports angles in [-90, 0); newer OpenCV releases use a different range,
    # in which case the first branch never runs - confirm against the
    # installed version.
    if angle < -45:
        angle = -(90 + angle)
    else:
        angle = -angle

    if abs(angle) > 30:   # avoid over-rotation
        return img

    (h, w) = img.shape[:2]
    # Rotate about the image centre without scaling.
    M = cv2.getRotationMatrix2D((w // 2, h // 2), angle, 1.0)
    # Output keeps the input size; edge pixels are replicated to fill corners.
    return cv2.warpAffine(
        img, M, (w, h),
        flags=cv2.INTER_CUBIC,
        borderMode=cv2.BORDER_REPLICATE
    )


# def extract_name_from_image(image_bytes):
#     np_arr = np.frombuffer(image_bytes, np.uint8)
#     base_img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)

#     for angle in [0, 90, 180, 270]:
#         img = rotate_image(base_img, angle)
#         img = deskew_image(img)
#         # === YOUR EXISTING LOGIC (UNCHANGED) ===
#         results_full = reader.readtext(img, detail=0)
#         text_full = "\n".join(results_full)

#         card_type = detect_card_type(text_full)

#         if card_type == "aadhaar":
#             roi = crop_aadhaar_region(img)
#             results = reader.readtext(roi, detail=0)
#         else:
#             results = results_full

#         text = "\n".join(results)

#         match = re.search(r"Name[^A-Za-z]*([A-Za-z\s]+)", text, re.IGNORECASE)
#         if match:
#             candidate = match.group(1).strip()
#             if candidate.lower() not in ["male", "female", "transgender", "government of india"]:
#                 return candidate

#         match = re.search(r"Elector'?s Name[^A-Za-z]*([A-Za-z\s]+)", text, re.IGNORECASE)
#         if match:
#             return match.group(1).strip()

#         for line in results:
#             line = line.strip()
#             if (re.match(r"^[A-Za-z\s]+$", line)
#                  and len(line.split()) >= 2
#                  and looks_like_real_name(line)
#                 ):
#                if line.lower() not in ["male", "female", "transgender", "government of india"]:
#                 line = re.sub(r"\b(MALE|FEMALE|Transgender)\b", "", line, flags=re.IGNORECASE).strip()
#                 if len(line.split()) >= 2:
#                  return line

#             # if re.match(r"^[A-Za-z\s]+$", line) and len(line.split()) >= 2:
#             #     if line.lower() not in ["male", "female", "transgender", "government of india"]:
#             #         line = re.sub(r"\b(MALE|FEMALE|Transgender)\b", "", line, flags=re.IGNORECASE).strip()
#             #         if len(line.split()) >= 2:
#             #             return line
#         # === END EXISTING LOGIC ===

#     return None

def extract_name_from_image(image_bytes):
    """Run OCR on an ID-card image and return the raw name string found on it.

    Pipeline:
      1. Decode the bytes (EXIF-aware first, plain OpenCV decode as fallback).
      2. Snap the image to an upright quarter turn once.
      3. For each of the four quarter turns: deskew, OCR, and look for a name
         via a "Name" label, an "Elector's Name" label, or the first line
         that looks like a multi-word name.

    Args:
        image_bytes: Encoded image data.

    Returns:
        The first name candidate found (uncleaned, may span several words or
        lines), or ``None`` if the bytes cannot be decoded as an image or no
        candidate is found in any orientation.
    """
    # Words that OCR often places where the name is expected but are not names.
    non_names = ("male", "female", "transgender", "government of india")
    name_label = re.compile(r"Name[^A-Za-z]*([A-Za-z\s]+)", re.IGNORECASE)
    elector_label = re.compile(r"Elector'?s Name[^A-Za-z]*([A-Za-z\s]+)", re.IGNORECASE)
    letters_only = re.compile(r"^[A-Za-z\s]+$")
    gender_word = re.compile(r"\b(MALE|FEMALE|Transgender)\b", re.IGNORECASE)

    if not image_bytes:
        return None

    # Step 1: Fix EXIF orientation, falling back to a plain OpenCV decode.
    base_img = fix_exif_orientation(image_bytes)
    if base_img is None:
        np_arr = np.frombuffer(image_bytes, np.uint8)
        base_img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
    if base_img is None:
        # Not an image OpenCV can read (e.g. a PDF): report "nothing found"
        # instead of crashing on base_img.shape further down.
        return None

    # Step 2: Auto-correct rotation once.
    base_img = auto_correct_rotation(base_img)

    # Step 3: Try deskew + brute-force fallback over all quarter turns.
    for angle in (0, 90, 180, 270):
        img = deskew_image(rotate_image(base_img, angle))

        # Keep only fragments the OCR engine is reasonably confident about.
        results_full = [
            text
            for (_bbox, text, conf) in reader.readtext(img, detail=1)
            if conf > 0.5
        ]
        card_type = detect_card_type("\n".join(results_full))

        if card_type == "aadhaar":
            roi = crop_aadhaar_region(img)
            results = reader.readtext(roi, detail=0)
            if len(" ".join(results).split()) < 3:  # too little text in the crop
                results = results_full
        else:
            results = results_full

        text = "\n".join(results)

        match = name_label.search(text)
        if match:
            candidate = match.group(1).strip()
            # The capture group can match whitespace only; an empty candidate
            # must not end the search before the fallbacks below have run.
            if candidate and candidate.lower() not in non_names:
                return candidate

        # Only reached when the generic "Name" match above was rejected.
        match = elector_label.search(text)
        if match:
            candidate = match.group(1).strip()
            if candidate:
                return candidate

        for line in results:
            line = line.strip()
            if not (
                letters_only.match(line)
                and len(line.split()) >= 2
                and looks_like_real_name(line)
            ):
                continue
            if line.lower() in non_names:
                continue
            # Drop a gender word that OCR merged into the name line.
            line = gender_word.sub("", line).strip()
            if len(line.split()) >= 2:
                return line

    return None


def clean_name(name: str) -> "str | None":
    """Normalise a raw OCR name into at most three alphabetic words.

    Whitespace (including newlines) is collapsed, tokens containing anything
    other than ASCII letters are dropped, and only the first three remaining
    tokens are kept.

    Args:
        name: Raw candidate string, or ``None``/empty.

    Returns:
        The cleaned name, or ``None`` when ``name`` is empty or no purely
        alphabetic token survives (previously that case returned ``""``
        while empty input returned ``None``).
    """
    if not name:
        return None
    # str.split() with no arguments already treats newlines and runs of
    # whitespace as a single separator.
    tokens = [t for t in name.split() if re.fullmatch(r"[A-Za-z]+", t)]
    if not tokens:
        return None
    return " ".join(tokens[:3])

# --- API ---
@app.post("/extract_name")
async def extract_name(data: InputData, response: Response):
    # Endpoint flow: download the document, OCR a name out of it, then
    # fuzzy-compare it with the caller-supplied name.
    #   400 -> the file could not be downloaded
    #   422 -> no name detected, or similarity below 80
    #   200 -> similarity of 80 or more
    payload = download_file(data.file)
    if not payload:
        response.status_code = 400
        return {
            "input_name": data.input_name,
            "extracted_name": None,
            "status": "error",
            "message": "Could not download file"
        }

    # PDF conversion (convert_pdf_to_image) is currently disabled; the
    # downloaded bytes are handed to the image pipeline as-is, and the card
    # type is auto-detected there.
    detected_name = clean_name(extract_name_from_image(payload))
    if not detected_name:
        response.status_code = 422
        return {
            "input_name": data.input_name,
            "extracted_name": None,
            "status": "not_detected",
            "message": "Name could not be detected"
        }

    # Case-insensitive comparison on trimmed strings.
    expected = data.input_name.strip().lower()
    found = detected_name.strip().lower()
    similarity = fuzz.ratio(expected, found)

    is_match = similarity >= 80
    response.status_code = 200 if is_match else 422

    return {
        "input_name": data.input_name,
        "extracted_name": detected_name,
        "status": "matched" if is_match else "not_matched",
        "similarity": similarity,
        "message": "Matched" if is_match else "Not matched"
    }