"""
Observability: Metrics, tracing, and structured logging
"""
import time
import logging
import structlog
from typing import Optional, Dict, Any, Callable
from functools import wraps
from datetime import datetime
from prometheus_client import Counter, Histogram, Gauge, CollectorRegistry, generate_latest
from contextlib import contextmanager

from config import settings

# Configure structured logging.
# Processor order matters: each processor receives the event dict produced by
# the previous one, and JSONRenderer must stay last because it turns the dict
# into the final string handed to the stdlib logger (LoggerFactory below).
# NOTE(review): `filter_by_level` defers to the stdlib logger's effective
# level, and this module imports `logging` but never configures it. Unless
# logging is configured elsewhere (e.g. logging.basicConfig(level=...) at app
# startup), the stdlib root default of WARNING would drop the `*_started` /
# `*_completed` info events emitted by the helpers in this file — confirm
# where stdlib logging is set up.
structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,  # drop events below the stdlib logger's level
        structlog.stdlib.add_logger_name,  # add the stdlib logger name to the event
        structlog.stdlib.add_log_level,  # add the method name (info/error/...) as the level
        structlog.stdlib.PositionalArgumentsFormatter(),  # apply %-style positional args
        structlog.processors.TimeStamper(fmt="iso"),  # ISO-8601 timestamp
        structlog.processors.StackInfoRenderer(),  # render stack_info=True requests
        structlog.processors.format_exc_info,  # render exc_info into a text traceback
        structlog.processors.UnicodeDecoder(),  # decode bytes values to str
        structlog.processors.JSONRenderer()  # serialize the event dict; must be last
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
    # Loggers are assembled once on first use, so configure() must run before
    # any logging happens (it does: this runs at import time).
    cache_logger_on_first_use=True,
)

# Module-wide structlog logger used by the tracking helpers below.
logger = structlog.get_logger()


# Prometheus metrics
registry = CollectorRegistry()

# Histogram bucket boundaries, in seconds. prometheus_client's default buckets
# stop at 10s, so any slower observation (long AI jobs, slow LLM responses,
# media uploads) would fall into the +Inf bucket and make latency quantiles
# meaningless. prometheus_client appends the +Inf bucket automatically.
JOB_DURATION_BUCKETS = (0.5, 1, 2.5, 5, 10, 30, 60, 120, 300, 600, 1800, 3600)
LLM_LATENCY_BUCKETS = (0.1, 0.25, 0.5, 1, 2.5, 5, 10, 20, 30, 60, 120)
PUBLISH_DURATION_BUCKETS = (0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30, 60, 120)

# Job metrics
job_requests_total = Counter(
    'ai_job_requests_total',
    'Total number of AI job requests',
    ['job_type', 'tenant_id'],
    registry=registry
)

job_duration_seconds = Histogram(
    'ai_job_duration_seconds',
    'AI job processing duration',
    ['job_type', 'status'],
    registry=registry,
    buckets=JOB_DURATION_BUCKETS,
)

job_status_total = Counter(
    'ai_job_status_total',
    'Total jobs by status',
    ['job_type', 'status'],
    registry=registry
)

# LLM metrics
llm_requests_total = Counter(
    'llm_requests_total',
    'Total LLM API requests',
    ['provider', 'model'],
    registry=registry
)

llm_tokens_total = Counter(
    'llm_tokens_total',
    'Total tokens used',
    ['provider', 'model', 'type'],  # type: input/output
    registry=registry
)

llm_cost_usd_total = Counter(
    'llm_cost_usd_total',
    'Total cost in USD',
    ['provider', 'model'],
    registry=registry
)

llm_latency_seconds = Histogram(
    'llm_latency_seconds',
    'LLM API latency',
    ['provider', 'model'],
    registry=registry,
    buckets=LLM_LATENCY_BUCKETS,
)

# Publish metrics
publish_requests_total = Counter(
    'publish_requests_total',
    'Total publish requests',
    ['platform', 'status'],
    registry=registry
)

publish_duration_seconds = Histogram(
    'publish_duration_seconds',
    'Publish operation duration',
    ['platform'],
    registry=registry,
    buckets=PUBLISH_DURATION_BUCKETS,
)

# System metrics
active_workers = Gauge(
    'celery_active_workers',
    'Number of active Celery workers',
    registry=registry
)

queue_length = Gauge(
    'celery_queue_length',
    'Number of tasks in queue',
    ['queue_name'],
    registry=registry
)


class MetricsCollector:
    """Facade for recording the module's Prometheus metrics.

    Every ``track_*`` / ``update_*`` method returns immediately, recording
    nothing, when ``settings.enable_metrics`` is falsy. ``get_metrics`` is not
    gated and always renders the module registry.
    """

    @staticmethod
    def track_job_request(job_type: str, tenant_id: str):
        """Count one incoming job request for ``job_type`` / ``tenant_id``."""
        if not settings.enable_metrics:
            return
        job_requests_total.labels(job_type=job_type, tenant_id=tenant_id).inc()

    @staticmethod
    def track_job_completion(job_type: str, status: str, duration: float):
        """Count a finished job by status and observe its duration in seconds."""
        if not settings.enable_metrics:
            return
        job_labels = {'job_type': job_type, 'status': status}
        job_status_total.labels(**job_labels).inc()
        job_duration_seconds.labels(**job_labels).observe(duration)

    @staticmethod
    def track_llm_request(provider: str, model: str, tokens_in: int,
                          tokens_out: int, cost: float, latency: float):
        """Record one LLM call: request count, tokens, cost and latency."""
        if not settings.enable_metrics:
            return
        model_labels = {'provider': provider, 'model': model}
        llm_requests_total.labels(**model_labels).inc()
        # Input tokens first, then output tokens, each under its own 'type'.
        for token_kind, token_count in (('input', tokens_in), ('output', tokens_out)):
            llm_tokens_total.labels(type=token_kind, **model_labels).inc(token_count)
        llm_cost_usd_total.labels(**model_labels).inc(cost)
        llm_latency_seconds.labels(**model_labels).observe(latency)

    @staticmethod
    def track_publish_request(platform: str, status: str, duration: float):
        """Count a publish attempt by status and observe its duration."""
        if not settings.enable_metrics:
            return
        publish_requests_total.labels(platform=platform, status=status).inc()
        publish_duration_seconds.labels(platform=platform).observe(duration)

    @staticmethod
    def update_worker_count(count: int):
        """Set the active-worker gauge to ``count``."""
        if not settings.enable_metrics:
            return
        active_workers.set(count)

    @staticmethod
    def update_queue_length(queue_name: str, length: int):
        """Set the queue-length gauge for ``queue_name`` to ``length``."""
        if not settings.enable_metrics:
            return
        queue_length.labels(queue_name=queue_name).set(length)

    @staticmethod
    def get_metrics() -> bytes:
        """Render the module registry in the Prometheus text exposition format."""
        exposition = generate_latest(registry)
        return exposition


class CostTracker:
    """Estimate the USD cost of LLM calls from per-1K-token price tables.

    All cost calculation is disabled (returns 0.0) when
    ``settings.track_costs`` is falsy.
    """

    # Pricing per 1K tokens (approximate, update as needed).
    # A provider may define a 'default' entry that applies to any of its
    # models that are not listed explicitly (used by 'local').
    PRICING = {
        'openai': {
            'gpt-4-turbo-preview': {'input': 0.01, 'output': 0.03},
            'gpt-4': {'input': 0.03, 'output': 0.06},
            'gpt-3.5-turbo': {'input': 0.0005, 'output': 0.0015},
            'text-embedding-3-small': {'input': 0.00002, 'output': 0},
            'text-embedding-3-large': {'input': 0.00013, 'output': 0},
        },
        'anthropic': {
            'claude-3-opus-20240229': {'input': 0.015, 'output': 0.075},
            'claude-3-sonnet-20240229': {'input': 0.003, 'output': 0.015},
            'claude-3-haiku-20240307': {'input': 0.00025, 'output': 0.00125},
        },
        'local': {
            'default': {'input': 0, 'output': 0},
        }
    }

    # (provider, model) pairs already reported as unpriced, so calculate_cost
    # warns once per pair instead of on every call.
    _unpriced_models = set()

    @classmethod
    def _get_model_pricing(cls, provider: str,
                           model: str) -> Optional[Dict[str, float]]:
        """Return the ``{'input', 'output'}`` price entry for a model.

        Falls back to the provider's ``'default'`` entry when the model is not
        listed; returns None when neither is available.
        """
        provider_pricing = cls.PRICING.get(provider, {})
        if model in provider_pricing:
            return provider_pricing[model]
        return provider_pricing.get('default')

    @classmethod
    def calculate_cost(cls, provider: str, model: str,
                       tokens_input: int, tokens_output: int) -> float:
        """Calculate the USD cost of a call from its token usage.

        Args:
            provider: Key into ``PRICING`` (e.g. 'openai').
            model: Model name within that provider's table.
            tokens_input: Prompt/input tokens consumed.
            tokens_output: Completion/output tokens produced.

        Returns:
            Cost rounded to 6 decimal places; 0.0 when cost tracking is
            disabled or no price is known for the model.
        """
        if not settings.track_costs:
            return 0.0

        model_pricing = cls._get_model_pricing(provider, model)
        if model_pricing is None:
            # Still cost it as zero, but surface the gap instead of silently
            # under-reporting spend for unpriced models.
            key = (provider, model)
            if key not in cls._unpriced_models:
                cls._unpriced_models.add(key)
                logger.warning("cost_pricing_missing", provider=provider, model=model)
            model_pricing = {'input': 0, 'output': 0}

        cost_input = (tokens_input / 1000) * model_pricing['input']
        cost_output = (tokens_output / 1000) * model_pricing['output']

        return round(cost_input + cost_output, 6)

    @staticmethod
    def _split_estimated_tokens(estimated_tokens: int,
                                input_ratio: float = 0.7) -> tuple:
        """Split a total token estimate into ``(input, output)`` counts.

        The output count is the remainder, so both parts always sum to
        ``estimated_tokens``. Truncating each share independently can drop
        tokens; for example, 3 tokens becomes (2, 0).

        Raises:
            ValueError: If ``input_ratio`` is outside [0, 1].
        """
        if not 0.0 <= input_ratio <= 1.0:
            raise ValueError(f"input_ratio must be within [0, 1], got {input_ratio}")
        tokens_in = int(estimated_tokens * input_ratio)
        return tokens_in, estimated_tokens - tokens_in

    @classmethod
    def estimate_cost(cls, provider: str, model: str,
                      estimated_tokens: int, input_ratio: float = 0.7) -> float:
        """Estimate cost for an operation from a total token estimate.

        Args:
            provider: Key into ``PRICING``.
            model: Model name within that provider's table.
            estimated_tokens: Expected total (input + output) tokens.
            input_ratio: Fraction of tokens assumed to be input; defaults to
                70% input / 30% output.

        Returns:
            Estimated cost as computed by ``calculate_cost``.
        """
        tokens_in, tokens_out = cls._split_estimated_tokens(estimated_tokens, input_ratio)
        return cls.calculate_cost(provider, model, tokens_in, tokens_out)


@contextmanager
def track_operation(operation_name: str, **context):
    """Context manager that logs and times the enclosed operation.

    Binds ``operation=operation_name`` plus any extra ``context`` keywords to
    the module logger and yields that bound logger, so the body can log with
    the same context. It emits three events:

    - ``<operation_name>_started`` on entry.
    - ``<operation_name>_completed`` with ``duration`` (seconds) on normal
      exit.
    - ``<operation_name>_failed`` with ``duration``, ``error``, ``error_type``
      and the traceback if the body raises an ``Exception``.

    The exception is re-raised unchanged.

    Args:
        operation_name: Prefix for the emitted event names.
        **context: Extra key/value pairs bound to every log event.

    Yields:
        The bound structlog logger.
    """
    # perf_counter is monotonic, unlike time.time(), so system clock
    # adjustments cannot produce negative or inflated durations.
    start_time = time.perf_counter()
    log = logger.bind(operation=operation_name, **context)
    log.info(f"{operation_name}_started")
    try:
        yield log
    except Exception as e:
        log.error(
            f"{operation_name}_failed",
            duration=time.perf_counter() - start_time,
            error=str(e),
            error_type=type(e).__name__,  # matches the track_*_operation decorators
            exc_info=True,  # rendered as a traceback by the format_exc_info processor
        )
        raise
    log.info(f"{operation_name}_completed", duration=time.perf_counter() - start_time)


def track_async_operation(operation_name: str):
    """Decorator that logs start, completion and failure of a coroutine function.

    Each call to the wrapped function emits three kinds of events:

    - ``<operation_name>_started`` before awaiting the wrapped coroutine.
    - ``<operation_name>_completed`` with ``duration`` (seconds) on success.
    - ``<operation_name>_failed`` with ``duration``, ``error``, ``error_type``
      and the traceback when it raises an ``Exception``.

    The exception is always re-raised. Events carry ``operation`` and
    ``function`` (the wrapped function's name) as bound context.

    Args:
        operation_name: Prefix for the emitted event names.

    Returns:
        A decorator for ``async def`` functions.
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Monotonic clock: immune to wall-clock jumps, unlike time.time().
            start_time = time.perf_counter()
            log = logger.bind(operation=operation_name, function=func.__name__)
            log.info(f"{operation_name}_started")

            try:
                result = await func(*args, **kwargs)
            except Exception as e:
                log.error(
                    f"{operation_name}_failed",
                    duration=time.perf_counter() - start_time,
                    error=str(e),
                    error_type=type(e).__name__,
                    exc_info=True,  # rendered as a traceback by format_exc_info
                )
                raise

            log.info(f"{operation_name}_completed", duration=time.perf_counter() - start_time)
            return result

        return wrapper
    return decorator


def track_sync_operation(operation_name: str):
    """Decorator that logs start, completion and failure of a regular function.

    Each call to the wrapped function emits three kinds of events:

    - ``<operation_name>_started`` before calling the wrapped function.
    - ``<operation_name>_completed`` with ``duration`` (seconds) on success.
    - ``<operation_name>_failed`` with ``duration``, ``error``, ``error_type``
      and the traceback when it raises an ``Exception``.

    The exception is always re-raised. Events carry ``operation`` and
    ``function`` (the wrapped function's name) as bound context.

    Args:
        operation_name: Prefix for the emitted event names.

    Returns:
        A decorator for synchronous functions.
    """
    def decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Monotonic clock: immune to wall-clock jumps, unlike time.time().
            start_time = time.perf_counter()
            log = logger.bind(operation=operation_name, function=func.__name__)
            log.info(f"{operation_name}_started")

            try:
                result = func(*args, **kwargs)
            except Exception as e:
                log.error(
                    f"{operation_name}_failed",
                    duration=time.perf_counter() - start_time,
                    error=str(e),
                    error_type=type(e).__name__,
                    exc_info=True,  # rendered as a traceback by format_exc_info
                )
                raise

            log.info(f"{operation_name}_completed", duration=time.perf_counter() - start_time)
            return result

        return wrapper
    return decorator


# Global instances. Both classes expose only static/class methods, so these
# are convenience handles; the classes can also be used directly.
metrics_collector, cost_tracker = MetricsCollector(), CostTracker()