import os
import pickle
import sys

import cv2
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

def extract_features(image):
    """Return ``[edge_density, mean_intensity, std_intensity]`` for *image*.

    An input with three dimensions (i.e. with a channel axis) is converted
    using ``cv2.COLOR_BGR2GRAY``; any other input is used as-is as the
    grayscale image. Edges are detected with ``cv2.Canny`` (hysteresis
    thresholds 50 and 150) on a 5x5 Gaussian-blurred copy of that image.

    Args:
        image: Image array, e.g. as returned by ``cv2.imread``.

    Returns:
        list: three numbers, in order -- the fraction of pixels Canny marks
        as edges, then the mean and the standard deviation of the unblurred
        grayscale pixel values.
    """
    is_multichannel = len(image.shape) == 3
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_multichannel else image

    # Smooth first so small-scale noise is less likely to be reported as edges.
    smoothed = cv2.GaussianBlur(gray, (5, 5), 0)
    edge_map = cv2.Canny(smoothed, 50, 150)

    height, width = edge_map.shape[0], edge_map.shape[1]
    edge_pixels = (edge_map > 0).sum()

    return [
        edge_pixels / (height * width),
        np.mean(gray),
        np.std(gray),
    ]

def load_real_dataset(base_path, max_per_class=1000):
    """Load labelled images under *base_path* and convert them to features.

    Two sub-folders are read: ``Negative`` (label 0, safe) and ``Positive``
    (label 1, cracked), in that order. Every decodable image is turned into
    a feature row with :func:`extract_features`.

    Args:
        base_path: Directory containing the ``Negative`` and ``Positive``
            sub-folders.
        max_per_class: Maximum number of directory entries considered per
            class, taken in sorted filename order so the selection is
            reproducible. ``None`` means no limit.

    Returns:
        tuple: ``(X, y)`` as NumPy arrays. ``X`` holds one feature row per
        loaded image and ``y`` the matching 0/1 labels. Both are empty when
        nothing could be loaded.

    Notes:
        A missing class folder is reported and skipped. Entries that
        ``cv2.imread`` cannot decode (it returns ``None``) are skipped
        silently.
    """
    print(f"Loading real dataset from {base_path}...")
    X = []
    y = []

    for folder, label in (("Negative", 0), ("Positive", 1)):
        class_path = os.path.join(base_path, folder)
        # isdir (not exists): a regular file with this name would make
        # os.listdir raise NotADirectoryError.
        if not os.path.isdir(class_path):
            print(f"Warning: {class_path} not found; skipping label {label}.")
            continue

        # os.listdir returns entries in arbitrary, filesystem-dependent
        # order; sort so the same subset is selected on every run/platform.
        file_names = sorted(os.listdir(class_path))[:max_per_class]
        for file_name in file_names:
            img = cv2.imread(os.path.join(class_path, file_name))
            if img is None:
                # Not something OpenCV can decode (stray file, sub-folder, ...).
                continue
            X.append(extract_features(img))
            y.append(label)

    return np.array(X), np.array(y)

if __name__ == "__main__":
    # The dataset root may be given as the first command-line argument; the
    # original hard-coded location remains the default.
    # NOTE(review): the default folder name "SpeakJet-dictionary" does not
    # obviously match a Negative/Positive image dataset -- confirm the path.
    dataset_path = (
        sys.argv[1]
        if len(sys.argv) > 1
        else r"F:\mydr\slab2025\dataset\SpeakJet-dictionary"
    )

    # Load at most 1000 images per class (2000 total) so training stays quick.
    X, y = load_real_dataset(dataset_path, max_per_class=1000)

    if len(X) == 0:
        print("Error: No images found. Check the dataset path.")
        # sys.exit, not the builtin exit(): the latter is injected by the
        # `site` module and is not guaranteed to exist.
        sys.exit(1)

    # A binary classifier cannot be fitted on a single label.
    if len(np.unique(y)) < 2:
        print("Error: Need images from both the Negative and Positive folders.")
        sys.exit(1)

    # Stratify so both splits keep the class ratio of the full dataset.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    print(f"Training SVM classifier on {len(X_train)} of {len(X)} real images...")
    # The features have very different ranges (edge density in [0, 1],
    # intensity statistics up to ~255). Standardising them keeps the
    # regularised SVM from effectively ignoring the small-range feature and
    # helps libsvm converge. The scaler is bundled with the SVM so the
    # pickled model still accepts raw extract_features() output via
    # predict()/predict_proba().
    # random_state seeds the internal CV used for probability estimates,
    # matching the seeded train/test split above.
    svm_model = make_pipeline(
        StandardScaler(),
        SVC(kernel='linear', probability=True, random_state=42),
    )
    svm_model.fit(X_train, y_train)

    # Accuracy is measured on the held-out 20% test split.
    preds = svm_model.predict(X_test)
    accuracy = accuracy_score(y_test, preds)
    print(f"Model trained with accuracy: {accuracy * 100:.2f}%")

    # Save next to this script, independent of the current working directory.
    model_save_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'svm_model.pkl'
    )
    with open(model_save_path, 'wb') as f:
        pickle.dump(svm_model, f)
    print(f"Model successfully saved to {model_save_path}")
