import torch
import timm
from torchvision import transforms
from PIL import Image
import os

# Paths
# Each path can be overridden with an environment variable so a different
# checkpoint/image can be used without editing this file; an unset or empty
# variable falls back to the hard-coded default.
MODEL_PATH = os.environ.get("OBSTACLE_MODEL_PATH") or "obstacle_classifier_timm.pt"
IMAGE_PATH = os.environ.get("OBSTACLE_IMAGE_PATH") or (
    "data_obstacle/no_obstacle/"
    "1759127899416-rn_image_picker_lib_temp_deb3c46e-f7f1-468c-a8ee-e184dab1c3ec.jpg"
)  # Replace with your actual image path, or set OBSTACLE_IMAGE_PATH

# Load model architecture and weights.
# The architecture must be the one the checkpoint was trained with:
# EfficientNet-B0 with a single output logit (num_classes=1).
# load_state_dict (strict by default) raises if the keys/shapes do not match.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = timm.create_model("efficientnet_b0", pretrained=False, num_classes=1)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot run arbitrary code when loaded. The file is
# only used as a state_dict, which this mode fully supports.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()  # inference mode for dropout / batch-norm layers

# Preprocessing pipeline applied to every input image. It has to mirror the
# transform used during training, otherwise the model sees inputs drawn from a
# different distribution than the one it learned.
_preprocess_steps = [
    # Resize straight to 224x224 (aspect ratio is not preserved).
    transforms.Resize((224, 224)),
    # PIL image -> float tensor (C, H, W) with values scaled to [0, 1].
    transforms.ToTensor(),
    # Per-channel standardisation with the standard ImageNet mean / std.
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Load and preprocess image.
# convert("RGB") yields a 3-channel copy (e.g. grayscale/palette/RGBA inputs are
# converted), so the underlying file can be closed as soon as it is made; the
# context manager guarantees that instead of relying on garbage collection.
with Image.open(IMAGE_PATH) as img_file:
    img = img_file.convert("RGB")
# Add a leading batch dimension of size 1 and move to the model's device.
x = transform(img).unsqueeze(0).to(device)

# Predict.
# Decision threshold applied to P(obstacle_present).
DECISION_THRESHOLD = 0.5

with torch.no_grad():
    out = model(x)
    # The model emits a single logit; sigmoid turns it into the probability of
    # class 1, which the label mapping below treats as "obstacle_present".
    prob = torch.sigmoid(out).item()
    pred = 1 if prob > DECISION_THRESHOLD else 0
    label = "obstacle_present" if pred == 1 else "no_obstacle"
    # Report confidence in the *predicted* label. `prob` is P(class 1), so for
    # a "no_obstacle" prediction the confidence is its complement; printing
    # `prob` directly would show a near-zero "confidence" for a sure negative.
    confidence = prob if pred == 1 else 1.0 - prob
    print(f"🧪 Prediction: {label} (Confidence: {confidence:.4f})")