# -*- coding: utf-8 -*-
"""
AI日志格式化器

统一不同AI服务的日志格式
"""
import datetime
from typing import Optional, Dict, Any
from enum import Enum


class LogLevel(Enum):
    """日志级别"""
    DEBUG = "DEBUG"
    INFO = "INFO"
    WARNING = "WARN"
    ERROR = "ERROR"
    SUCCESS = "SUCCESS"
    STAGE = "STAGE"  # 阶段信息
    TOOL = "TOOL"    # 工具调用
    SYSTEM = "SYSTEM"  # 系统消息


class AILogFormatter:
    """AI日志格式化器"""
    
    # 日志格式模板
    # 格式: [YYYY-MM-DD HH:MM:SS] [LEVEL] [SERVICE] message
    FORMAT_TEMPLATE = "[{timestamp}] [{level}] [{service}] {message}"
    
    # 阶段格式模板
    STAGE_TEMPLATE = "[{timestamp}] [{level}] [{service}] 🎯 [阶段 {stage}] {stage_name}"
    
    # 工具格式模板
    TOOL_TEMPLATE = "[{timestamp}] [{level}] [{service}] 🔧 {tool_name}: {message}"
    
    # 系统消息格式模板
    SYSTEM_TEMPLATE = "[{timestamp}] [{level}] [{service}] ⚙️  {message}"
    
    # 分隔线格式
    SEPARATOR_TEMPLATE = "[{timestamp}] [{level}] [{service}] {separator}"
    
    def __init__(self, service_name: str):
        """
        初始化日志格式化器
        
        Args:
            service_name: 服务名称 (如 'AIService', 'Augment', 'ClaudeRouter')
        """
        self.service_name = service_name
    
    def format(
        self,
        message: str,
        level: LogLevel = LogLevel.INFO,
        stage: Optional[int] = None,
        stage_name: Optional[str] = None,
        tool_name: Optional[str] = None,
        **kwargs
    ) -> str:
        """
        格式化日志消息
        
        Args:
            message: 日志消息
            level: 日志级别
            stage: 阶段编号（可选）
            stage_name: 阶段名称（可选）
            tool_name: 工具名称（可选）
            **kwargs: 其他参数
            
        Returns:
            格式化后的日志字符串
        """
        timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        level_str = level.value
        
        # 阶段信息
        if stage is not None and stage_name:
            return self.STAGE_TEMPLATE.format(
                timestamp=timestamp,
                level=level_str,
                service=self.service_name,
                stage=stage,
                stage_name=stage_name
            )
        
        # 工具调用
        if tool_name:
            return self.TOOL_TEMPLATE.format(
                timestamp=timestamp,
                level=level_str,
                service=self.service_name,
                tool_name=tool_name,
                message=message
            )
        
        # 系统消息
        if level == LogLevel.SYSTEM:
            return self.SYSTEM_TEMPLATE.format(
                timestamp=timestamp,
                level=level_str,
                service=self.service_name,
                message=message
            )
        
        # 普通消息
        return self.FORMAT_TEMPLATE.format(
            timestamp=timestamp,
            level=level_str,
            service=self.service_name,
            message=message
        )
    
    def format_separator(self, message: Optional[str] = None) -> str:
        """
        格式化分隔线
        
        Args:
            message: 分隔线中的消息（可选）
            
        Returns:
            格式化后的分隔线
        """
        timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        if message:
            return self.format(message, level=LogLevel.SYSTEM)
        else:
            return self.SEPARATOR_TEMPLATE.format(
                timestamp=timestamp,
                level=LogLevel.SYSTEM.value,
                service=self.service_name,
                separator='=' * 60
            )
    
    def format_stage(
        self,
        stage: int,
        stage_name: str,
        status: str = "开始"
    ) -> str:
        """
        格式化阶段信息
        
        Args:
            stage: 阶段编号
            stage_name: 阶段名称
            status: 状态（开始/完成/跳过等）
            
        Returns:
            格式化后的阶段日志
        """
        message = f"🎯 [阶段 {stage}] {stage_name} - {status}"
        return self.format(message, level=LogLevel.STAGE, stage=stage, stage_name=stage_name)
    
    def format_tool(
        self,
        tool_name: str,
        message: str,
        level: LogLevel = LogLevel.INFO
    ) -> str:
        """
        格式化工具调用信息
        
        Args:
            tool_name: 工具名称
            message: 工具消息
            level: 日志级别
            
        Returns:
            格式化后的工具日志
        """
        return self.format(message, level=level, tool_name=tool_name)
    
    def format_system(self, message: str) -> str:
        """
        格式化系统消息
        
        Args:
            message: 系统消息
            
        Returns:
            格式化后的系统日志
        """
        return self.format(message, level=LogLevel.SYSTEM)
    
    def format_error(self, message: str, error: Optional[Exception] = None) -> str:
        """
        格式化错误消息
        
        Args:
            message: 错误消息
            error: 异常对象（可选）
            
        Returns:
            格式化后的错误日志
        """
        if error:
            error_msg = f"{message}: {str(error)}"
            if hasattr(error, '__traceback__'):
                import traceback
                error_msg += f"\n{traceback.format_exc()}"
            return self.format(error_msg, level=LogLevel.ERROR)
        return self.format(message, level=LogLevel.ERROR)
    
    def format_success(self, message: str) -> str:
        """
        格式化成功消息
        
        Args:
            message: 成功消息
            
        Returns:
            格式化后的成功日志
        """
        return self.format(message, level=LogLevel.SUCCESS)
    
    def format_warning(self, message: str) -> str:
        """
        格式化警告消息
        
        Args:
            message: 警告消息
            
        Returns:
            格式化后的警告日志
        """
        return self.format(message, level=LogLevel.WARNING)
    
    def format_info(self, message: str) -> str:
        """
        格式化信息消息
        
        Args:
            message: 信息消息
            
        Returns:
            格式化后的信息日志
        """
        return self.format(message, level=LogLevel.INFO)
    
    def format_debug(self, message: str) -> str:
        """
        格式化调试消息
        
        Args:
            message: 调试消息
            
        Returns:
            格式化后的调试日志
        """
        return self.format(message, level=LogLevel.DEBUG)


class UnifiedLogWriter:
    """统一的日志写入器"""
    
    def __init__(self, log_file: str, service_name: str):
        """
        初始化统一日志写入器
        
        Args:
            log_file: 日志文件路径
            service_name: 服务名称
        """
        self.log_file = log_file
        self.formatter = AILogFormatter(service_name)
        self._file_handle = None
    
    def __enter__(self):
        """上下文管理器入口"""
        self._file_handle = open(self.log_file, 'w', encoding='utf-8')
        return self
    
    def __exit__(self, exc_type, exc_val, exc_tb):
        """上下文管理器出口"""
        if self._file_handle:
            self._file_handle.close()
    
    def write(
        self,
        message: str,
        level: LogLevel = LogLevel.INFO,
        stage: Optional[int] = None,
        stage_name: Optional[str] = None,
        tool_name: Optional[str] = None,
        flush: bool = True
    ):
        """
        写入日志
        
        Args:
            message: 日志消息
            level: 日志级别
            stage: 阶段编号（可选）
            stage_name: 阶段名称（可选）
            tool_name: 工具名称（可选）
            flush: 是否立即刷新
        """
        formatted = self.formatter.format(
            message=message,
            level=level,
            stage=stage,
            stage_name=stage_name,
            tool_name=tool_name
        )
        
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            if flush:
                self._file_handle.flush()
    
    def write_raw(self, line: str, flush: bool = True):
        """
        写入原始行（不格式化）
        
        Args:
            line: 原始行
            flush: 是否立即刷新
        """
        if self._file_handle:
            self._file_handle.write(line)
            if flush:
                self._file_handle.flush()
    
    def write_stage(self, stage: int, stage_name: str, status: str = "开始"):
        """写入阶段信息"""
        formatted = self.formatter.format_stage(stage, stage_name, status)
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            self._file_handle.flush()
    
    def write_tool(self, tool_name: str, message: str, level: LogLevel = LogLevel.INFO):
        """写入工具调用信息"""
        formatted = self.formatter.format_tool(tool_name, message, level)
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            self._file_handle.flush()
    
    def write_system(self, message: str):
        """写入系统消息"""
        formatted = self.formatter.format_system(message)
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            self._file_handle.flush()
    
    def write_error(self, message: str, error: Optional[Exception] = None):
        """写入错误消息"""
        formatted = self.formatter.format_error(message, error)
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            self._file_handle.flush()
    
    def write_success(self, message: str):
        """写入成功消息"""
        formatted = self.formatter.format_success(message)
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            self._file_handle.flush()
    
    def write_warning(self, message: str):
        """写入警告消息"""
        formatted = self.formatter.format_warning(message)
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            self._file_handle.flush()
    
    def write_info(self, message: str):
        """写入信息消息"""
        formatted = self.formatter.format_info(message)
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            self._file_handle.flush()
    
    def write_separator(self, message: Optional[str] = None):
        """写入分隔线"""
        formatted = self.formatter.format_separator(message)
        if self._file_handle:
            self._file_handle.write(formatted + '\n')
            self._file_handle.flush()

