import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
import google.generativeai as genai
import logging
from functools import lru_cache
import re


# Module-wide log format: timestamp, level, message.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
_LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT)
logger = logging.getLogger(__name__)



class TextEmbedder:
    """Embed brand and influencer profile text with a Gemini model and score their similarity.

    Profile rows (``pd.Series``) are flattened into ``"column:value | column:value"``
    strings, normalized by ``_preprocess_text``, embedded with
    ``genai.embed_content`` and compared with cosine similarity clipped to
    ``[0, 1]``. Blank texts and failed API calls are represented by zero
    vectors, which score 0 against everything.
    """

    # Task type used for every embedding request, so vectors produced by the
    # single-text and batch paths are comparable with each other.
    _TASK_TYPE = "RETRIEVAL_DOCUMENT"

    # Removed outright so abbreviations/possessives stay one word
    # ("u.s." -> "us", "women's" -> "womens").
    _JOINING_PUNCT_RE = re.compile(r"['\u2019.]")
    # Any other character outside word chars, whitespace, hyphen and the
    # separators produced by _combine_text_features ("|", ":", ",") becomes a
    # space, so e.g. "health/fitness" stays two words.
    _SEPARATING_PUNCT_RE = re.compile(r"[^\w\s|:,-]")
    _WHITESPACE_RE = re.compile(r"\s+")

    def __init__(self, api_key: str, model_name: str = "models/text-embedding-004",
                 batch_size: int = 50, cache_size: int = 5000):
        """Configure the Gemini client and probe the model's embedding dimension.

        Args:
            api_key: Google AI API key passed to ``genai.configure``.
            model_name: Embedding model identifier.
            batch_size: Maximum number of texts per batch request; must be >= 1.
            cache_size: Maximum number of single-text embeddings memoized by
                ``get_embedding`` (0 or less disables the cache).

        Raises:
            ValueError: If ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        # Use the caller's key; credentials must never be hard-coded in source.
        genai.configure(api_key=api_key)
        self.model = model_name
        self.batch_size = batch_size
        self._cache_size = cache_size
        # Per-instance LRU cache (insertion-ordered dict) of successful embeddings.
        # An @lru_cache on the method would keep every instance alive for the
        # cache's lifetime and would also memoize zero vectors from API errors.
        self._embedding_cache: Dict[str, np.ndarray] = {}
        self.embedding_dim = self._get_model_dimension()
        logger.info(f"Initialized with embedding dimension: {self.embedding_dim}")

    def _get_model_dimension(self) -> int:
        """Return the embedding length produced by ``self.model``, or 768 if the probe fails."""
        try:
            test_embedding = genai.embed_content(
                model=self.model,
                content="dimension test",
                task_type=self._TASK_TYPE,
            )['embedding']
            return len(test_embedding)
        except Exception as e:
            logger.error(f"Failed to get model dimension: {str(e)}")
            logger.info("Defaulting to 768 dimensions")
            return 768

    def _preprocess_text(self, text: str) -> str:
        """Lower-case *text* and strip punctuation while keeping feature separators.

        Commas are preserved so multi-valued fields ("fashion, beauty") do not
        collapse into one token. Whitespace is collapsed *after* punctuation is
        removed, so no double spaces are left behind.
        """
        text = text.lower()
        text = self._JOINING_PUNCT_RE.sub("", text)
        text = self._SEPARATING_PUNCT_RE.sub(" ", text)
        return self._WHITESPACE_RE.sub(" ", text).strip()

    def _combine_text_features(self, row: pd.Series, text_columns: List[str]) -> str:
        """Join the present, non-null, non-blank *text_columns* of *row* as ``"col:value | ..."``.

        A literal ``|`` inside a value is rewritten to ``", "`` so it cannot be
        confused with the column separator. The result is passed through
        ``_preprocess_text``.
        """
        features = []
        for col in text_columns:
            if col not in row or pd.isna(row[col]):
                continue
            value = str(row[col]).replace("|", ", ")
            if value.strip():  # skip empty cells instead of emitting a bare "col:"
                features.append(f"{col}:{value}")
        return self._preprocess_text(" | ".join(features))

    def get_brand_text_features(self, brand: pd.Series) -> str:
        """Build the combined, preprocessed feature text for a brand row."""
        text_columns = [
            'industry', 'target_audience', 'brand_messaging',
            'tone_voice', 'category_alignment',
            'brand_alignment_keywords', 'content_type'
        ]
        return self._combine_text_features(brand, text_columns)

    def get_influencer_text_features(self, influencer: pd.Series) -> str:
        """Build the combined, preprocessed feature text for an influencer row."""
        text_columns = [
            'category_niche', 'audience_demographics',
            'audience_interests', 'content_types'
        ]
        return self._combine_text_features(influencer, text_columns)

    def _cache_embedding(self, text: str, embedding: np.ndarray) -> None:
        """Store *embedding* for *text*, evicting least-recently-used entries beyond the cache size."""
        if self._cache_size <= 0:
            return
        self._embedding_cache[text] = embedding
        while len(self._embedding_cache) > self._cache_size:
            # Dicts iterate in insertion order and get_embedding re-inserts on
            # every hit, so the first key is the least recently used one.
            del self._embedding_cache[next(iter(self._embedding_cache))]

    def get_embedding(self, text: str) -> np.ndarray:
        """Return the L2-normalized embedding of *text*.

        Blank text, or a failed API call, yields a zero vector of length
        ``embedding_dim``. Only successful results are cached, so transient
        failures are retried on the next call. A fresh copy is returned each
        time, so callers may mutate the result without corrupting the cache.
        """
        if not text.strip():
            return np.zeros(self.embedding_dim)

        cached = self._embedding_cache.pop(text, None)
        if cached is not None:
            self._embedding_cache[text] = cached  # move to most-recently-used slot
            return cached.copy()

        try:
            result = genai.embed_content(
                model=self.model,
                content=text,
                task_type=self._TASK_TYPE,
            )
            embedding = normalize(np.array(result['embedding'], dtype=float).reshape(1, -1))[0]
        except Exception as e:
            logger.error(f"Embedding error: {str(e)} | Text: {text[:100]}...")
            return np.zeros(self.embedding_dim)

        self._cache_embedding(text, embedding)
        return embedding.copy()

    def batch_get_embeddings(self, texts: List[str]) -> np.ndarray:
        """Embed *texts* in batches of ``batch_size``.

        Returns:
            An array of shape ``(len(texts), embedding_dim)`` with L2-normalized
            rows. Blank texts, and every text in a batch whose request fails,
            get zero rows.
        """
        texts = list(texts)
        if not texts:
            # sklearn's normalize rejects 0-row input, so answer directly.
            return np.zeros((0, self.embedding_dim))

        matrix = np.zeros((len(texts), self.embedding_dim))
        # Blank texts stay zero rows (matching get_embedding) and are not sent,
        # so one empty profile cannot make a whole batch request fail.
        pending = [(idx, txt) for idx, txt in enumerate(texts) if txt and txt.strip()]
        for start in range(0, len(pending), self.batch_size):
            chunk = pending[start:start + self.batch_size]
            try:
                # embed_content accepts a list and returns one vector per item.
                response = genai.embed_content(
                    model=self.model,
                    content=[txt for _, txt in chunk],
                    task_type=self._TASK_TYPE,
                )
                vectors = np.asarray(response['embedding'], dtype=float)
                # Guard explicitly: a short response would otherwise broadcast.
                if vectors.shape != (len(chunk), self.embedding_dim):
                    raise ValueError(f"unexpected embedding shape {vectors.shape}")
            except Exception as e:
                logger.error(f"Batch embedding failed: {str(e)}")
                continue  # rows for this chunk stay zero
            matrix[[idx for idx, _ in chunk]] = vectors
        return normalize(matrix)

    def calculate_text_similarity(self, brand_text: str, influencer_text: str) -> float:
        """Return the cosine similarity of the two texts' embeddings, clipped to [0, 1]."""
        brand_embedding = self.get_embedding(brand_text)
        influencer_embedding = self.get_embedding(influencer_text)

        similarity = cosine_similarity(
            brand_embedding.reshape(1, -1),
            influencer_embedding.reshape(1, -1)
        )[0][0]

        return float(np.clip(similarity, 0, 1))

    def get_similarity_matrix(self, brands_df: pd.DataFrame, influencers_df: pd.DataFrame) -> np.ndarray:
        """Return a ``(n_brands, n_influencers)`` matrix of similarities clipped to [0, 1]."""
        brand_texts = [self.get_brand_text_features(row) for _, row in brands_df.iterrows()]
        influencer_texts = [self.get_influencer_text_features(row) for _, row in influencers_df.iterrows()]

        # cosine_similarity rejects 0-row inputs; skip the API calls entirely.
        if not brand_texts or not influencer_texts:
            return np.zeros((len(brand_texts), len(influencer_texts)))

        brand_embeddings = self.batch_get_embeddings(brand_texts)
        influencer_embeddings = self.batch_get_embeddings(influencer_texts)

        similarity_matrix = cosine_similarity(brand_embeddings, influencer_embeddings)
        return np.clip(similarity_matrix, 0, 1)

    def _extract_feature_terms(self, text: str) -> set:
        """Split a combined feature string into its set of individual value terms.

        ``"industry:fashion, beauty | tone_voice:playful"`` gives
        ``{"fashion", "beauty", "playful"}``. The ``column:`` prefix is dropped
        because brand and influencer features are stored under different column
        names and can only be compared by value.
        """
        terms = set()
        for feature in text.split("|"):
            _, has_prefix, value = feature.partition(":")
            if not has_prefix:
                value = feature
            for term in value.split(","):
                term = term.strip()
                if term:
                    terms.add(term)
        return terms

    def analyze_feature_alignment(self, brand_text: str, influencer_text: str) -> Dict:
        """Compare the value terms present in a brand's and an influencer's feature text.

        Returns:
            A dict with sorted lists ``common_features``, ``unique_brand_features``
            and ``unique_influencer_features``, plus ``feature_overlap_ratio``:
            the share of brand terms also found for the influencer (0.0 when the
            brand has no terms).
        """
        brand_features = self._extract_feature_terms(brand_text)
        influencer_features = self._extract_feature_terms(influencer_text)

        common_features = brand_features & influencer_features

        # Sorted so printed summaries are deterministic across runs.
        return {
            'common_features': sorted(common_features),
            'unique_brand_features': sorted(brand_features - influencer_features),
            'unique_influencer_features': sorted(influencer_features - brand_features),
            'feature_overlap_ratio': len(common_features) / max(len(brand_features), 1)
        }

    def print_detailed_match_analysis(self, brand: pd.Series, influencer: pd.Series, similarity_score: float):
        """Print a human-readable feature-alignment report and score interpretation to stdout."""
        brand_text = self.get_brand_text_features(brand)
        influencer_text = self.get_influencer_text_features(influencer)
        alignment = self.analyze_feature_alignment(brand_text, influencer_text)

        print("\n" + "="*80)
        print(f"Match Analysis - Brand: {brand.get('name', 'Unknown')} vs Influencer: {influencer.get('name', 'Unknown')}")
        print("-"*80)

        print("\nFeature Alignment:")
        print(f"Common Features ({len(alignment['common_features'])}):")
        for feat in alignment['common_features'][:5]:
            print(f"  - {feat}")

        print(f"\nBrand Unique Features ({len(alignment['unique_brand_features'])}):")
        for feat in alignment['unique_brand_features'][:3]:
            print(f"  - {feat}")

        print(f"\nInfluencer Unique Features ({len(alignment['unique_influencer_features'])}):")
        for feat in alignment['unique_influencer_features'][:3]:
            print(f"  - {feat}")

        print("\n" + "-"*80)
        print(f"Text Similarity Score: {similarity_score:.4f}")
        print("Score Interpretation:")
        self._print_score_interpretation(similarity_score)
        print("="*80)

    def _print_score_interpretation(self, score: float):
        """Print the label of the highest threshold that *score* meets (prints nothing below 0)."""
        thresholds = [
            (0.9, "Exceptional Match", "Near-perfect alignment in brand/influencer characteristics"),
            (0.7, "Strong Match", "High potential for successful collaboration"),
            (0.5, "Moderate Match", "Potential with some adjustments needed"),
            (0.3, "Weak Match", "Limited alignment - consider carefully"),
            (0.0, "Poor Match", "Unlikely to be a good fit")
        ]

        for threshold, title, description in thresholds:
            if score >= threshold:
                print(f"{title} (≥{threshold:.1f}): {description}")
                return

    def save_embeddings(self, df: pd.DataFrame, output_path: str, entity_type: str = "brand"):
        """Embed every row of *df* and save the matrix with ``np.save``.

        ``entity_type == "brand"`` uses the brand feature columns; any other
        value uses the influencer columns. Note that ``np.save`` appends
        ``.npy`` to *output_path* if it lacks that suffix.
        """
        if entity_type == "brand":
            to_text = self.get_brand_text_features
        else:
            to_text = self.get_influencer_text_features
        texts = [to_text(row) for _, row in df.iterrows()]

        embeddings = self.batch_get_embeddings(texts)
        np.save(output_path, embeddings)
        logger.info(f"Saved {entity_type} embeddings to {output_path}")

    def load_embeddings(self, input_path: str) -> np.ndarray:
        """Load an embedding matrix previously written by ``save_embeddings``."""
        return np.load(input_path)