import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import MinMaxScaler
from typing import Dict, List, Tuple, Set
import logging
from ..models.text_embedder import TextEmbedder
from ..database.db_connector import DatabaseConnector
# Module-level logger used by SimilarityScorer for all diagnostics.
logger = logging.getLogger(__name__)


class SimilarityScorer:
    """Score how well influencers match brands.

    Each brand/influencer pair gets a blended score:

    * a "traditional" score, the weighted sum (``config['similarity_weights']``)
      of category, audience, numerical and compliance sub-scores; and
    * a text-similarity score delegated to ``TextEmbedder``.

    The two halves are weighted 50/50. Text fields are compared
    case-insensitively, and missing values (``None``, ``NaN``, empty strings)
    never count as a match.
    """

    # String spellings treated as False when parsing boolean-like fields.
    _FALSE_STRINGS = frozenset({'false', 'f', 'no', 'n', '0'})

    def __init__(self, config: Dict):
        """Build the scorer from a configuration mapping.

        Args:
            config: Must contain ``similarity_weights``, ``category_relationships``
                and ``audience_relationships``. The scoring methods read
                ``matching``. ``text_embedding``, ``pinecone`` and ``database``
                are optional.
        """
        self.config = config
        self.similarity_weights = config['similarity_weights']
        # Relationship groups are lower-cased once, because every lookup
        # compares them against lower-cased record fields.
        self.related_categories = {
            str(k).lower(): {str(c).lower() for c in v}
            for k, v in config['category_relationships'].items()
        }
        self.related_audiences = {
            str(k).lower(): {str(a).lower() for a in v}
            for k, v in config['audience_relationships'].items()
        }
        # NOTE(review): not read by any method of this class; kept as a public attribute.
        self.scaler = MinMaxScaler()

        # Initialize the text embedder. A missing or None section falls back to defaults.
        pinecone_cfg = config.get('pinecone') or {}
        self.text_embedder = TextEmbedder(
            gemini_api_key=(config.get('text_embedding') or {}).get('gemini_api_key'),
            pinecone_config={
                'api_key': pinecone_cfg.get('api_key', ''),
                'index_name': pinecone_cfg.get('index_name', 'recommendationsystempro'),
                'namespace': pinecone_cfg.get('namespace', 'influencer-matching'),
            },
        )

        # Initialize database connector if database config exists; failure is non-fatal.
        self.db_connector = None
        if 'database' in self.config:
            try:
                self.db_connector = DatabaseConnector(self.config)
            except Exception as e:
                logger.warning(f"Could not initialize database connection: {str(e)}")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _normalize_text(value) -> str:
        """Return *value* as stripped lower-case text, or ``''`` for None/NaN."""
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ''
        return str(value).strip().lower()

    @classmethod
    def _split_tokens(cls, value) -> Set[str]:
        """Split a '/'-separated field into a set of non-empty lower-case tokens."""
        return {token.strip() for token in cls._normalize_text(value).split('/') if token.strip()}

    @classmethod
    def _to_bool(cls, value, default: bool) -> bool:
        """Interpret a boolean-like field.

        None/NaN and empty strings give *default*. Strings such as
        ``'false'``/``'no'``/``'0'`` give False. Any other value uses Python
        truthiness.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return default
        if isinstance(value, str):
            token = value.strip().lower()
            if not token:
                return default
            return token not in cls._FALSE_STRINGS
        return bool(value)

    @staticmethod
    def _to_native(value):
        """Convert numpy scalars to plain Python values (e.g. for DB drivers)."""
        return value.item() if isinstance(value, np.generic) else value

    # --------------------------------------------------------------- category

    def _get_related_categories(self, category: str) -> Set[str]:
        """Return the first relationship group containing *category*, or an empty set.

        The returned group includes its main category. Matching is case-insensitive.
        """
        category = category.strip().lower()
        if not category:
            return set()
        for main_cat, related in self.related_categories.items():
            if category == main_cat or category in related:
                return set(related) | {main_cat}
        return set()

    def _calculate_category_similarity_embedding(self, brand: pd.Series, influencer: pd.Series) -> float:
        """Category similarity from text embeddings, in [0, 1].

        Returns 0.0 without calling the embedder when either side has no
        category data. Falls back to the rule-based score when embedding or
        comparison raises.
        """
        brand_industry = self._normalize_text(brand.get('industry', ''))
        brand_alignment = self._normalize_text(brand.get('category_alignment', ''))
        influencer_niche = self._normalize_text(influencer.get('category_niche', ''))
        # Without data, the embeddings would only compare the fixed template
        # text below and report a spurious similarity.
        if not influencer_niche or not (brand_industry or brand_alignment):
            return 0.0
        try:
            # Combine the category data with descriptive context
            brand_category_text = f"Brand industry: {brand_industry}. Brand category alignment: {brand_alignment}"
            influencer_category_text = f"Influencer category/niche: {influencer_niche}"
            brand_embedding = np.asarray(self.text_embedder.get_embedding(brand_category_text), dtype=float)
            influencer_embedding = np.asarray(self.text_embedder.get_embedding(influencer_category_text), dtype=float)
            similarity = float(cosine_similarity(
                brand_embedding.reshape(1, -1),
                influencer_embedding.reshape(1, -1),
            )[0][0])
            # Cosine similarity lies in [-1, 1], and a negative base raised to
            # 0.7 is NaN, so clamp first. The power curve then lifts every
            # score in (0, 1), since x ** 0.7 >= x there.
            adjusted_similarity = min(max(similarity, 0.0), 1.0) ** 0.7
            logger.debug(
                "Embedding-based category similarity score: %.2f for %s/%s -> %s",
                adjusted_similarity, brand_industry, brand_alignment, influencer_niche,
            )
            return adjusted_similarity
        except Exception as e:
            logger.warning(f"Error using embeddings for category similarity: {str(e)}, falling back to rule-based method")
            return self._calculate_category_similarity_rule_based(brand, influencer)

    def _calculate_category_similarity_rule_based(self, brand: pd.Series, influencer: pd.Series) -> float:
        """Token-overlap category similarity, in [0, 1].

        Direct industry matches weigh 0.6, alignment matches 0.3 and matches
        via relationship groups 0.1, normalised by the number of brand
        industries. The score is damped by 0.2 when there is no direct or
        alignment match.
        """
        brand_categories = self._split_tokens(brand.get('industry', ''))
        brand_alignment = self._split_tokens(brand.get('category_alignment', ''))
        influencer_categories = self._split_tokens(influencer.get('category_niche', ''))

        expanded_brand_cats = set().union(
            *(self._get_related_categories(cat) for cat in brand_categories | brand_alignment)
        )
        expanded_influencer_cats = set().union(
            *(self._get_related_categories(cat) for cat in influencer_categories)
        )

        direct_matches = len(brand_categories & influencer_categories)
        alignment_matches = len(brand_alignment & influencer_categories)
        related_matches = len(expanded_brand_cats & expanded_influencer_cats)

        score = (
            direct_matches * 0.6 +
            alignment_matches * 0.3 +
            related_matches * 0.1
        ) / max(len(brand_categories), 1)
        if direct_matches == 0 and alignment_matches == 0:
            score *= 0.2
        # Related matches are unbounded (group size), so cap to keep the
        # score on the same [0, 1] scale as the embedding variant.
        return min(score, 1.0)

    def _calculate_category_similarity(self, brand: pd.Series, influencer: pd.Series) -> float:
        """Category similarity: embedding-based, with a rule-based fallback."""
        return self._calculate_category_similarity_embedding(brand, influencer)

    # --------------------------------------------------------------- audience

    def _calculate_audience_similarity(self, brand: pd.Series, influencer: pd.Series) -> float:
        """Audience fit in [0, 1]: 0.5 demographics + 0.3 geography + 0.2 language.

        Demographics score 1.0 on a substring match and 0.7 when both sides
        fall in the same configured audience group. Empty fields never match.
        """
        brand_audience = self._normalize_text(brand.get('target_audience', ''))
        influencer_audience = self._normalize_text(influencer.get('audience_demographics', ''))

        demographic_match = 0.0
        related_match = 0.0
        # An empty string is a substring of everything, so require both sides.
        if brand_audience and influencer_audience:
            demographic_match = float(brand_audience in influencer_audience or
                                      influencer_audience in brand_audience)
            for main_audience, related in self.related_audiences.items():
                group = {a.lower() for a in related} | {main_audience.lower()}
                if brand_audience in group and influencer_audience in group:
                    related_match = 0.7
                    break

        brand_geo = self._normalize_text(brand.get('geographic_target', ''))
        influencer_loc = self._normalize_text(influencer.get('location', ''))
        geo_match = float(
            brand_geo == 'global' or
            (bool(brand_geo) and bool(influencer_loc) and
             (brand_geo in influencer_loc or influencer_loc in brand_geo)) or
            (brand_geo == 'north america' and influencer_loc in ('usa', 'canada'))
        )

        brand_lang = self._split_tokens(brand.get('language_preferences', ''))
        influencer_lang = self._split_tokens(influencer.get('languages', ''))
        lang_match = len(brand_lang & influencer_lang) / max(len(brand_lang), 1)

        return max(demographic_match, related_match) * 0.5 + geo_match * 0.3 + lang_match * 0.2

    # -------------------------------------------------------------- numerical

    def _safe_float(self, value, default=0.0) -> float:
        """Parse *value* as a float.

        Zero, NaN and unparsable values are all treated as missing and
        replaced by *default*.
        """
        try:
            result = float(value)
        except (ValueError, TypeError):
            return default
        if result == 0 or np.isnan(result):
            return default
        return result

    def _safe_division(self, numerator, denominator, default=0.0) -> float:
        """Divide after :meth:`_safe_float` parsing; *default* if the denominator is 0."""
        num = self._safe_float(numerator)
        den = self._safe_float(denominator)
        if den == 0:
            return default
        return num / den

    def _calculate_numerical_similarity(self, brand: pd.Series, influencer: pd.Series) -> float:
        """Gate on follower, engagement and budget requirements.

        Returns 0.0 when the influencer is below the brand's follower or
        engagement minimum, or when ``cost_per_post * posts_per_campaign``
        exceeds the campaign budget (a missing cost counts as infinite).
        Otherwise returns the mean of the ratios (each capped at 2.0), capped
        at 1.0. With positive thresholds every ratio is >= 1 once the gates
        pass, so qualifying pairs effectively score 1.0.
        """
        min_followers = self._safe_float(brand.get('min_follower_range'), 1.0)
        actual_followers = self._safe_float(influencer.get('follower_count'), 0.0)
        if actual_followers < min_followers:
            return 0.0

        min_engagement = self._safe_float(brand.get('min_engagement_rate'), 0.01)
        actual_engagement = self._safe_float(influencer.get('engagement_rate'), 0.0)
        if actual_engagement < min_engagement:
            return 0.0

        posts_per_campaign = self.config['matching']['posts_per_campaign']
        campaign_budget = self._safe_float(brand.get('campaign_budget'), 0.0)
        cost_per_post = self._safe_float(influencer.get('cost_per_post'), float('inf'))
        total_cost = cost_per_post * posts_per_campaign
        if total_cost > campaign_budget:
            return 0.0

        scores = [
            min(self._safe_division(actual_followers, min_followers, 0.0), 2.0),
            min(self._safe_division(actual_engagement, min_engagement, 0.0), 2.0),
        ]
        # A zero total cost (e.g. posts_per_campaign == 0) would divide by zero.
        if campaign_budget > 0 and 0 < total_cost < float('inf'):
            scores.append(min(campaign_budget / total_cost, 2.0))
        return float(min(np.mean(scores), 1.0))

    # ------------------------------------------------------------- compliance

    def _calculate_compliance_similarity(self, brand: pd.Series, influencer: pd.Series) -> float:
        """Compliance fit: 0.0 if a controversy-free brand meets a flagged influencer.

        Otherwise returns 1.0 for a 'verified' compliance status and 0.5 for
        anything else. A missing controversy flag counts as flagged.
        """
        requires_controversy_free = self._to_bool(brand.get('requires_controversy_free', False), False)
        controversy_flag = self._to_bool(influencer.get('controversy_flag', True), True)
        if requires_controversy_free and controversy_flag:
            return 0.0
        compliance_match = self._normalize_text(influencer.get('compliance_status', '')) == 'verified'
        # Past the gate the controversy check always passes, contributing 1.0.
        return (1.0 + float(compliance_match)) / 2

    # ---------------------------------------------------------------- scoring

    def _calculate_traditional_score(self, brand: pd.Series,
                                     influencer: pd.Series) -> Tuple[float, float, float]:
        """Return ``(weighted_score, category_score, numerical_score)`` for one pair."""
        category_score = self._calculate_category_similarity(brand, influencer)
        audience_score = self._calculate_audience_similarity(brand, influencer)
        numerical_score = self._calculate_numerical_similarity(brand, influencer)
        compliance_score = self._calculate_compliance_similarity(brand, influencer)
        weights = self.similarity_weights
        weighted_score = (
            category_score * weights['category'] +
            audience_score * weights['audience'] +
            numerical_score * weights['numerical'] +
            compliance_score * weights['compliance']
        )
        return weighted_score, category_score, numerical_score

    def calculate_similarity_matrix(self, brands_features: pd.DataFrame,
                                    influencers_features: pd.DataFrame) -> np.ndarray:
        """Return a (n_brands, n_influencers) matrix of blended match scores.

        Rows and columns follow the DataFrames' positional order, so duplicate
        index labels are handled. The traditional score is zeroed when the
        numerical gate fails. It is multiplied by ``matching.category_penalty``
        (default 0.5) when the category score is below
        ``matching.min_category_score`` (default 0.3). The matrix is then
        divided by its maximum (when positive) and capped at 0.95.
        """
        similarity_matrix = np.zeros((len(brands_features), len(influencers_features)))
        if similarity_matrix.size == 0:
            # Nothing to score; also avoids max() on an empty array.
            return similarity_matrix

        matching = self.config.get('matching') or {}
        min_category_score = matching.get('min_category_score', 0.3)
        category_penalty = matching.get('category_penalty', 0.5)

        # Influencer text features do not depend on the brand: build them once.
        influencers = [row for _, row in influencers_features.iterrows()]
        influencer_texts = [self.text_embedder.get_influencer_text_features(row) for row in influencers]

        for brand_pos, (_, brand) in enumerate(brands_features.iterrows()):
            brand_text = self.text_embedder.get_brand_text_features(brand)
            for infl_pos, (influencer, influencer_text) in enumerate(zip(influencers, influencer_texts)):
                text_score = self.text_embedder.calculate_text_similarity(brand_text, influencer_text)
                traditional_score, category_score, numerical_score = \
                    self._calculate_traditional_score(brand, influencer)
                if numerical_score == 0.0:
                    traditional_score = 0.0
                elif category_score < min_category_score:
                    traditional_score *= category_penalty
                similarity_matrix[brand_pos, infl_pos] = 0.5 * traditional_score + 0.5 * text_score

        max_score = similarity_matrix.max()
        if max_score > 0:
            similarity_matrix = similarity_matrix / max_score
        # NOTE(review): the 0.95 cap presumably keeps any match from reading as perfect.
        return np.minimum(similarity_matrix, 0.95)

    def get_top_matches(self, similarity_matrix: np.ndarray,
                        brands_df: pd.DataFrame,
                        influencers_df: pd.DataFrame) -> List[Tuple[str, str, float]]:
        """Return up to ``matching.top_n`` matches per brand, best first.

        Each match is ``(brand index label, influencer index label, score)``
        with the score rounded to 3 decimals. Only pairs scoring at least
        ``matching.similarity_threshold`` are kept.

        Args:
            similarity_matrix: Currently not read. Scores are recomputed per
                pair, unnormalised and with the category penalty applied to
                the blended score, so they can differ from the matrix values.
                Kept for interface compatibility.
        """
        matching = self.config['matching']
        top_n = matching['top_n']
        min_similarity = matching['similarity_threshold']
        min_category_score = matching.get('min_category_score', 0.3)
        category_penalty = matching.get('category_penalty', 0.5)

        # Influencer text features do not depend on the brand: build them once.
        influencers = [row for _, row in influencers_df.iterrows()]
        influencer_texts = [self.text_embedder.get_influencer_text_features(row) for row in influencers]

        matches = []
        for _, brand in brands_df.iterrows():
            brand_text = self.text_embedder.get_brand_text_features(brand)
            brand_matches = []
            for influencer, influencer_text in zip(influencers, influencer_texts):
                traditional_score, category_score, numerical_score = \
                    self._calculate_traditional_score(brand, influencer)
                text_score = self.text_embedder.calculate_text_similarity(brand_text, influencer_text)
                final_score = 0.5 * traditional_score + 0.5 * text_score
                if numerical_score == 0.0:
                    final_score = 0.0
                elif category_score < min_category_score:
                    final_score *= category_penalty
                if final_score >= min_similarity:
                    brand_matches.append((brand.name, influencer.name, round(final_score, 3)))
            brand_matches.sort(key=lambda match: match[2], reverse=True)
            matches.extend(brand_matches[:top_n])
        return matches

    # --------------------------------------------------------------- database

    def save_matches_to_database(self, matches: List[Tuple[str, str, float]]) -> bool:
        """Persist ``(brand_id, influencer_id, score)`` tuples to the ``matches`` table.

        Creates the table if needed. Numpy scalars are converted to plain
        Python values first. Returns True on success, and False when no
        connector is configured or any database call raises.
        """
        if not self.db_connector:
            logger.error("Database connector not available. Cannot save matches.")
            return False
        try:
            match_data = [
                {
                    'brand_id': self._to_native(brand_id),
                    'influencer_id': self._to_native(influencer_id),
                    'similarity_score': float(score),
                }
                for brand_id, influencer_id, score in matches
            ]
            self.db_connector.execute_query("""
                CREATE TABLE IF NOT EXISTS matches (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    brand_id VARCHAR(50),
                    influencer_id VARCHAR(50),
                    similarity_score FLOAT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            """)
            # Skip the insert for an empty batch; some drivers reject empty executemany calls.
            if match_data:
                self.db_connector.insert_matches(match_data)
            logger.info(f"Saved {len(matches)} matches to database")
            return True
        except Exception as e:
            logger.exception(f"Error saving matches to database: {str(e)}")
            return False
Preview:
Download PNG
Download JPEG
Download SVG
Tip: You can change the style, width & colours of the snippet with the inspect tool before clicking Download!
Click to optimize width for Twitter