from flask import Blueprint, jsonify, request
import pandas as pd
import os
from datetime import timedelta
from statsmodels.tsa.statespace.sarimax import SARIMAX
from models import SensorData, User, session
from werkzeug.security import generate_password_hash, check_password_hash
from flask_login import login_user, logout_user, login_required, current_user

# Blueprint that every route in this module is attached to; presumably
# registered on the Flask app elsewhere (not visible here).
bp = Blueprint(name='api', import_name=__name__)

# --- AUTHENTICATION & PASSWORD RECOVERY ROUTES ---

@bp.route('/api/signup', methods=['POST'])
def signup():
    """Create a new user account.

    Expects a JSON object with ``username``, ``email``, ``password``,
    ``security_question_1``, ``security_answer_1``, ``security_question_2``
    and ``security_answer_2``. The password is stored via
    ``User.set_password`` and the security answers only as hashes.

    Returns:
        201 on success; 400 if the body is not a JSON object or a required
        field is missing, empty or not a string; 409 if the username or
        email is already taken.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"message": "Request body must be a JSON object."}), 400

    # Every field is required: a missing answer/password would otherwise reach
    # generate_password_hash / set_password as None and crash with a 500.
    required_fields = (
        'username', 'email', 'password',
        'security_question_1', 'security_answer_1',
        'security_question_2', 'security_answer_2',
    )
    missing = [
        field for field in required_fields
        if not isinstance(data.get(field), str) or not data.get(field)
    ]
    if missing:
        return jsonify({"message": f"Missing required fields: {', '.join(missing)}."}), 400

    if session.query(User).filter_by(username=data['username']).first():
        return jsonify({"message": "Username already exists."}), 409
    if session.query(User).filter_by(email=data['email']).first():
        return jsonify({"message": "Email already registered."}), 409

    new_user = User(
        username=data['username'],
        email=data['email'],
        security_question_1=data['security_question_1'],
        security_answer_1_hash=generate_password_hash(data['security_answer_1']),
        security_question_2=data['security_question_2'],
        security_answer_2_hash=generate_password_hash(data['security_answer_2']),
    )
    new_user.set_password(data['password'])
    session.add(new_user)
    try:
        session.commit()
    except Exception:
        # `session` is a module-level import and presumably shared across
        # requests; roll back so a failed commit (e.g. a concurrent signup
        # hitting a unique constraint) does not poison later requests.
        session.rollback()
        raise
    return jsonify({"message": "Account created successfully."}), 201

@bp.route('/api/login', methods=['POST'])
def login():
    """Authenticate a user and start a Flask-Login session.

    Expects a JSON object with ``username`` and ``password``.

    Returns:
        200 on success; 400 if the body is not a JSON object; 401 if the
        credentials are missing, not strings, or do not match a user.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"message": "Request body must be a JSON object."}), 400

    username = data.get('username')
    password = data.get('password')
    # Reject missing/non-string credentials up front instead of handing None
    # to check_password; use the same message so nothing extra is revealed.
    if not isinstance(username, str) or not isinstance(password, str):
        return jsonify({"message": "Invalid username or password."}), 401

    user = session.query(User).filter_by(username=username).first()
    if not user or not user.check_password(password):
        return jsonify({"message": "Invalid username or password."}), 401
    login_user(user)
    return jsonify({"message": "Logged in successfully."}), 200

@bp.route('/api/logout', methods=['POST'])
@login_required
def logout():
    """End the current user's Flask-Login session.

    Guarded by ``@login_required``; returns 200 with a confirmation message.
    """
    logout_user()
    response = jsonify({"message": "You have been logged out."})
    return response, 200

@bp.route('/api/check_session', methods=['GET'])
def check_session():
    """Report whether the caller currently has an authenticated session.

    Returns 401 with ``loggedIn: False`` for anonymous callers, otherwise
    200 with ``loggedIn: True`` and the username.
    """
    if not current_user.is_authenticated:
        return jsonify({"loggedIn": False}), 401
    body = {"loggedIn": True, "username": current_user.username}
    return jsonify(body), 200

@bp.route('/api/get-security-questions', methods=['POST'])
def get_security_questions():
    """Return a user's two security questions (first step of password reset).

    Expects a JSON object with ``username``.

    Returns:
        200 with ``question1``/``question2``; 400 if the body is not a JSON
        object; 404 if the user does not exist or has no first security
        question stored.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"message": "Request body must be a JSON object."}), 400

    # NOTE(review): 404 vs 200 reveals whether a username exists; confirm that
    # enumeration is acceptable for this endpoint.
    user = session.query(User).filter_by(username=data.get('username')).first()
    if not user or not user.security_question_1:
        return jsonify({"message": "Unable to retrieve security questions for this user."}), 404
    return jsonify({"question1": user.security_question_1, "question2": user.security_question_2}), 200

@bp.route('/api/reset-password', methods=['POST'])
def reset_password():
    """Reset a user's password after verifying both security answers.

    Expects a JSON object with ``username``, ``answer1``, ``answer2`` and
    ``newPassword``.

    Returns:
        200 on success; 400 if the body is not a JSON object, or if the
        answers are correct but ``newPassword`` is missing/empty; 401 if the
        user is unknown, has no stored answers, or either answer is wrong.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"message": "Request body must be a JSON object."}), 400

    user = session.query(User).filter_by(username=data.get('username')).first()
    if not user:
        return jsonify({"message": "Invalid credentials."}), 401

    answer1 = data.get('answer1')
    answer2 = data.get('answer2')
    new_password = data.get('newPassword')

    # A missing stored hash or a non-string answer would make
    # check_password_hash fail with a 500; treat it as a failed verification.
    if not (user.security_answer_1_hash and user.security_answer_2_hash
            and isinstance(answer1, str) and isinstance(answer2, str)):
        return jsonify({"message": "One or more security answers are incorrect."}), 401

    # Both answers are checked before combining, so the response does not
    # depend on which one was wrong.
    answer1_correct = check_password_hash(user.security_answer_1_hash, answer1)
    answer2_correct = check_password_hash(user.security_answer_2_hash, answer2)
    if not (answer1_correct and answer2_correct):
        return jsonify({"message": "One or more security answers are incorrect."}), 401

    if not isinstance(new_password, str) or not new_password:
        return jsonify({"message": "A new password is required."}), 400

    user.set_password(new_password)
    try:
        session.commit()
    except Exception:
        # Keep the (presumably shared) module-level session usable after a
        # failed commit.
        session.rollback()
        raise
    return jsonify({"message": "Password has been reset successfully. Please log in."}), 200


# --- DATA ROUTES ---
@bp.route('/status', methods=['GET'])
def status():
    """Health-check endpoint: confirms the API process is serving requests."""
    return jsonify(status='API is running')

@bp.route('/api/upload-mock-data', methods=['POST'])
def upload_mock_data():
    """Insert every row of ``mock_sensor_data.csv`` into the SensorData table.

    The path is resolved relative to the process's current working directory.
    Rows are inserted on every call; nothing is deduplicated.

    Returns:
        200 with a success message; 404 if the CSV file does not exist;
        500 with the error text if the file cannot be parsed, a required
        column is missing, or the insert fails (pending inserts are rolled
        back).
    """
    # NOTE(review): this route has no @login_required, so any client can
    # insert data. Confirm that is intended (e.g. development-only).
    file_path = "mock_sensor_data.csv"
    if not os.path.exists(file_path):
        return jsonify({"error": "Mock file not found"}), 404
    try:
        df = pd.read_csv(file_path)
        df["timestamp"] = pd.to_datetime(df["timestamp"])
        entries = [
            SensorData(
                timestamp=row.timestamp,
                sensor_id=row.sensor_id,
                rainfall_mm=row.rainfall_mm,
                water_level_cm=row.water_level_cm,
                flow_rate_lps=row.flow_rate_lps,
            )
            for row in df.itertuples(index=False)
        ]
        session.add_all(entries)
        session.commit()
    except Exception as e:
        # Discard any partially-added rows so the module-level session is not
        # left dirty or in a failed-transaction state for later requests.
        session.rollback()
        return jsonify({"error": f"Failed to upload mock data: {str(e)}"}), 500
    return jsonify({"message": "Mock data uploaded successfully!"})

@bp.route('/api/get-sensor-data', methods=['GET'])
@login_required
def get_sensor_data():
    """Return every SensorData row as a JSON list, oldest first.

    Each item has ``timestamp`` (``"%Y-%m-%d %H:%M:%S"`` string, or null if
    unset), ``sensor_id``, ``rainfall_mm``, ``water_level_cm`` and
    ``flow_rate_lps``.

    Returns:
        200 with the list, or 500 with the error text if the query fails.
    """
    try:
        # Explicit ordering: without ORDER BY the row order is database-defined.
        rows = session.query(SensorData).order_by(SensorData.timestamp).all()
        result = [
            {
                "timestamp": (
                    entry.timestamp.strftime("%Y-%m-%d %H:%M:%S")
                    if entry.timestamp is not None else None
                ),
                "sensor_id": entry.sensor_id,
                "rainfall_mm": entry.rainfall_mm,
                "water_level_cm": entry.water_level_cm,
                "flow_rate_lps": entry.flow_rate_lps,
            }
            for entry in rows
        ]
        return jsonify(result)
    except Exception as e:
        # A failed query leaves the shared session needing a rollback before
        # it can be reused by subsequent requests.
        session.rollback()
        return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500

# --- FORECASTING ROUTE ---

def _forecast_series(series, steps):
    """Return the next ``steps`` values of an hourly ``series`` as a list.

    Series with fewer than two distinct values (flat lines, common for
    rainfall) skip SARIMAX and repeat the last value instead.
    """
    if series.nunique() < 2:
        return [series.iloc[-1]] * steps
    model = SARIMAX(
        series,
        order=(1, 1, 1),
        seasonal_order=(1, 1, 0, 12),
        enforce_stationarity=False,
        enforce_invertibility=False,
    )
    results = model.fit(disp=False)
    return results.get_forecast(steps=steps).predicted_mean.tolist()


def _json_safe(values, *, non_negative=False):
    """Convert forecast values to plain floats, mapping NaN to None (JSON null).

    When ``non_negative`` is true, values are clamped at 0.
    """
    cleaned = []
    for value in values:
        if pd.isna(value):
            # NaN would otherwise be serialized as the non-standard token `NaN`.
            cleaned.append(None)
            continue
        value = float(value)
        cleaned.append(max(0.0, value) if non_negative else value)
    return cleaned


@bp.route('/api/forecast', methods=['GET'])
@login_required
def forecast():
    """Forecast the next 24 hourly values of water level, flow rate and rainfall.

    Readings are resampled to hourly means (gaps forward-filled), then each
    metric is forecast independently.

    Returns:
        200 with ``{metric: {"timestamps": [...], "predicted_values": [...]}}``
        for ``waterLevel``, ``flowRate`` and ``rainfall``; 400 if fewer than
        24 readings exist; 500 with the error text if loading or fitting fails.
    """
    forecast_steps = 24
    metrics_to_forecast = {
        "waterLevel": "water_level_cm",
        "flowRate": "flow_rate_lps",
        "rainfall": "rainfall_mm",
    }
    try:
        query = session.query(SensorData).statement
        df = pd.read_sql(query, session.bind)

        if len(df) < 24:  # Need enough data to forecast
            return jsonify({"error": "Not enough data to create a forecast."}), 400

        df['timestamp'] = pd.to_datetime(df['timestamp'])
        df = df.set_index('timestamp').sort_index()

        forecast_results = {}
        for key, column_name in metrics_to_forecast.items():
            # NOTE(review): readings from all sensor_ids are averaged together
            # per hour; confirm a single aggregate forecast is intended.
            series = df[column_name].resample('h').mean().ffill()

            # Label the steps from the last hourly bin the model was fitted on
            # (not the raw last reading) so timestamps match the forecast steps.
            first_step = series.index[-1] + timedelta(hours=1)
            future_dates = pd.date_range(start=first_step, periods=forecast_steps, freq='h')

            predicted = _forecast_series(series, forecast_steps)
            forecast_results[key] = {
                "timestamps": [d.strftime("%Y-%m-%d %H:%M:%S") for d in future_dates],
                # Rainfall cannot be negative, so its forecast is clamped at 0.
                "predicted_values": _json_safe(predicted, non_negative=(key == 'rainfall')),
            }

        return jsonify(forecast_results)

    except Exception as e:
        return jsonify({"error": f"Failed to generate forecast: {str(e)}"}), 500