import os
import cv2
import prediction_utils
from typing import Dict, List
import numpy as np

def test_pipeline(images_directory: str, masks_directory: str, output_directory: str):
    """
    Testuje cały pipeline przetwarzania obrazów krok po kroku:
    1. Filtruje małe obszary klas.
    2. Filtruje dane na podstawie cech (choose_frame_1).
    3. Wyznacza punkty i linie bazowe.
    4. Filtruje dane na podstawie cech linii i punktów (choose_frame_2).
    5. Oblicza kąty alfa.

    Wyniki pośrednie zapisywane są na każdym etapie do odpowiednich katalogów.

    Args:
        images_directory (str): Ścieżka do katalogu z obrazami.
        masks_directory (str): Ścieżka do katalogu z maskami.
        output_directory (str): Ścieżka do katalogu wyjściowego.
    """
    os.makedirs(output_directory, exist_ok=True)
    
    # Wczytanie obrazów i masek
    images = [cv2.imread(os.path.join(images_directory, f)) 
              for f in os.listdir(images_directory) if f.endswith('.png')]
    masks = [cv2.imread(os.path.join(masks_directory, f), 0) 
             for f in os.listdir(masks_directory) if f.endswith('.png')]

    # Wybierz podzbiór danych
    images = images[700:800]
    masks = masks[700:800]

    data = {
        "images": images,
        "masks": masks
    }

    print(f"Initial number of images: {len(data['images'])}")
    print(f"Initial number of masks: {len(data['masks'])}")

    # Krok 1: Filtracja małych klas
    step_1_dir = os.path.join(output_directory, 'step_1_small_class_filter')
    os.makedirs(step_1_dir, exist_ok=True)
    print("1. Filtracja małych klas (`choose_frame_remove_small_areas`)...")
    data = prediction_utils.choose_frame_remove_small_areas(data)
    log_data_statistics(data, "Po filtracji małych klas")
    save_intermediate_results(data, step_1_dir)

    # Krok 2: Filtracja na podstawie cech (`choose_frame_1`)
    step_2_dir = os.path.join(output_directory, 'step_2_feature_filter')
    os.makedirs(step_2_dir, exist_ok=True)
    print("2. Filtracja na podstawie cech (`choose_frame_1`)...")
    data = prediction_utils.choose_frame_1(data)
    log_data_statistics(data, "Po filtracji na podstawie cech")
    save_intermediate_results(data, step_2_dir)

    # Krok 3: Wyznaczanie punktów i linii bazowych
    step_3_dir = os.path.join(output_directory, 'step_3_calculate_points')
    os.makedirs(step_3_dir, exist_ok=True)
    print("3. Wyznaczanie punktów i linii bazowych (`calculate_points_and_baseline_5class`)...")
    data = prediction_utils.calculate_points_and_baseline_5class(data)
    log_data_statistics(data, "Po wyznaczeniu punktów i linii bazowych")
    save_intermediate_results(data, step_3_dir)

    # Krok 4: Filtracja na podstawie punktów i linii (`choose_frame_2`)
    step_4_dir = os.path.join(output_directory, 'step_4_filter_points_and_lines')
    os.makedirs(step_4_dir, exist_ok=True)
    print("4. Filtracja na podstawie punktów i linii (`choose_frame_2`)...")
    data = prediction_utils.choose_frame_2(data)
    log_data_statistics(data, "Po filtracji na podstawie punktów i linii")
    save_intermediate_results(data, step_4_dir)

    # Krok 5: Obliczanie kąta alfa
    step_5_dir = os.path.join(output_directory, 'step_5_calculate_angles')
    os.makedirs(step_5_dir, exist_ok=True)
    print("5. Obliczanie kąta alfa (`identify_alpha_beta_angle_new`)...")
    try:
        result = prediction_utils.identify_alpha_beta_angle_new(data)
        print(f"Największy kąt alfa: {result['alpha']}")
        save_alpha_results(result, step_5_dir)

    except ValueError as e:
        print(f"Błąd podczas obliczania kąta alfa: {e}")

def log_data_statistics(data: Dict[str, List], stage: str):
    """
    Loguje liczbę obrazów i masek w danych po każdym etapie przetwarzania.

    Args:
        data (Dict[str, List]): Dane pośrednie.
        stage (str): Opis etapu przetwarzania.
    """
    num_images = len(data.get('images', []))
    num_masks = len(data.get('masks', []))
    print(f"{stage} - Liczba obrazów: {num_images}, Liczba masek: {num_masks}")

def save_intermediate_results(data: Dict[str, List], output_dir: str):
    """
    Zapisuje obrazy i maski z pośredniego etapu przetwarzania.

    Args:
        data (Dict[str, List]): Dane przetworzone w bieżącym kroku.
        output_dir (str): Katalog, do którego zapisywane są wyniki.
    """
    images_dir = os.path.join(output_dir, 'images')
    masks_dir = os.path.join(output_dir, 'masks')
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(masks_dir, exist_ok=True)

    for idx, (img, mask) in enumerate(zip(data['images'], data['masks'])):
        cv2.imwrite(os.path.join(images_dir, f'image_{idx}.png'), img)
        cv2.imwrite(os.path.join(masks_dir, f'mask_{idx}.png'), mask)

def save_alpha_results(result: Dict[str, float | np.ndarray], output_dir: str):
    """
    Zapisuje wyniki obliczeń kąta alfa.

    Args:
        result (Dict[str, float | np.ndarray]): Wyniki obliczeń kąta alfa.
        output_dir (str): Katalog, do którego zapisywane są wyniki.
    """
    cv2.imwrite(os.path.join(output_dir, 'image_with_max_alpha.png'), result['image'])
    cv2.imwrite(os.path.join(output_dir, 'mask_with_max_alpha.png'), result['mask'])
    cv2.imwrite(os.path.join(output_dir, 'angle_lines_mask.png'), result['angle_lines_mask'])

# Ścieżki do katalogów z danymi
images_directory = './app/angle_utils_5class/images'
masks_directory = './app/angle_utils_5class/masks'
output_directory = './app/angle_utils_5class/output'

# Uruchomienie testu pipeline
test_pipeline(images_directory, masks_directory, output_directory)