# -*- coding: utf-8 -*-
"""
上下文管理器

负责管理 AI 对话的上下文，包括消息压缩、阶段摘要等
"""
import logging
from typing import Dict, List, Any, Optional

from ..core.stage_detector import StageDetector

logger = logging.getLogger(__name__)


class ContextManager:
    """上下文管理器
    
    负责管理对话上下文，包括：
    - 消息压缩
    - 阶段摘要管理
    - 文件跟踪
    """
    
    def __init__(self, log_callback=None):
        """初始化上下文管理器
        
        Args:
            log_callback: 日志回调函数
        """
        self.log_callback = log_callback
        self.stage_artifacts: Dict[int, Dict[str, Any]] = {}
        self.generated_files: List[str] = []
        self.context_config = {
            'max_messages': 100,
            'max_tool_results_length': 10000,
            'compress_threshold': 99999,  # 禁用压缩（设置为极大值）
        }
    
    def _log(self, level: str, message: str):
        """记录日志"""
        if self.log_callback:
            self.log_callback(level, message)
        log_func = getattr(logger, level, logger.info)
        log_func(message)
    
    def extract_stage_artifact(self, stage: int, content: str) -> Dict[str, Any]:
        """从 AI 输出中提取阶段关键信息
        
        Args:
            stage: 阶段号
            content: AI 输出内容
            
        Returns:
            阶段摘要字典
        """
        info = StageDetector.extract_stage_info(stage, content)
        artifact = {'stage': stage, 'raw_summary': ''}
        if info:
            artifact.update(info)
        return artifact
    
    def save_stage_artifact(self, stage: int, content: str):
        """保存阶段关键信息
        
        Args:
            stage: 阶段号
            content: AI 输出内容
        """
        artifact = self.extract_stage_artifact(stage, content)
        self.stage_artifacts[stage] = artifact
        self._log('debug', f'保存阶段 {stage} 摘要: {artifact.get("raw_summary", "")}')
    
    def track_generated_file(self, file_path: str):
        """跟踪生成的文件
        
        Args:
            file_path: 文件路径
        """
        if file_path and file_path not in self.generated_files:
            self.generated_files.append(file_path)
    
    def build_context_summary(self) -> str:
        """构建上下文摘要
        
        Returns:
            上下文摘要字符串
        """
        return StageDetector.build_context_summary(self.stage_artifacts, self.generated_files)
    
    def compress_messages(
        self, 
        messages: List[Dict], 
        current_stage: int,
        filtered_stages: Optional[List[Dict]] = None
    ) -> List[Dict]:
        """压缩消息列表，保留关键信息
        
        策略：
        1. 保留系统提示词
        2. 保留阶段摘要
        3. 保留最近 N 条消息
        4. 截断过长的工具结果
        
        Args:
            messages: 原始消息列表
            current_stage: 当前阶段
            filtered_stages: 过滤后的阶段列表（可选）
            
        Returns:
            压缩后的消息列表
        """
        if len(messages) < self.context_config['compress_threshold']:
            return messages
        
        self._log('info', f'触发上下文压缩: {len(messages)} 条消息 → 压缩中...')
        
        compressed = []
        max_messages = self.context_config['max_messages']
        max_tool_len = self.context_config['max_tool_results_length']
        
        # 1. 保留系统提示词
        if messages and messages[0].get('role') == 'system':
            compressed.append(messages[0])
            messages = messages[1:]
        
        # 2. 插入上下文摘要（如果有之前阶段的信息）
        context_summary = self.build_context_summary()
        if context_summary:
            compressed.append({
                'role': 'user',
                'content': context_summary + "\n\n请基于以上历史信息继续当前任务。"
            })
            # 获取阶段ID用于获取阶段名称
            stage_id = self._get_stage_id_by_index(current_stage, filtered_stages) if filtered_stages else str(current_stage)
            stage_name = self._get_stage_name(stage_id, filtered_stages)
            compressed.append({
                'role': 'assistant',
                'content': f"好的，我已了解之前阶段的关键信息。当前在阶段 {current_stage}（{stage_name}），继续执行。"
            })
        
        # 3. 保留最近的消息，确保 tool 消息和对应的 tool_calls 成对保留
        recent_messages = messages[-max_messages:] if len(messages) > max_messages else messages
        
        # 检查第一条消息是否是 tool 消息，如果是则向前查找对应的 assistant 消息
        if recent_messages and recent_messages[0].get('role') == 'tool':
            start_idx = len(messages) - len(recent_messages)
            if start_idx > 0:
                for i in range(start_idx - 1, -1, -1):
                    msg = messages[i]
                    recent_messages = [msg] + recent_messages
                    if msg.get('role') == 'assistant' and msg.get('tool_calls'):
                        break
        
        for msg in recent_messages:
            new_msg = msg.copy()
            
            # 截断过长的工具结果
            if msg.get('role') == 'tool' and len(msg.get('content', '')) > max_tool_len:
                content = msg['content']
                new_msg['content'] = content[:max_tool_len] + f"\n... [已截断，原长度 {len(content)} 字符]"
            
            # 截断过长的助手消息（但保留关键结构）
            if msg.get('role') == 'assistant' and len(msg.get('content', '')) > 5000:
                content = msg['content']
                if '|' in content:
                    new_msg['content'] = content[:5000] + "\n... [已截断]"
                else:
                    new_msg['content'] = content[:3000] + "\n... [已截断]"
            
            compressed.append(new_msg)
        
        self._log('info', f'上下文压缩完成: {len(messages) + 1} → {len(compressed)} 条消息')
        return compressed
    
    def _get_stage_name(self, stage_id: str, filtered_stages: Optional[List[Dict]] = None) -> str:
        """获取阶段名称（辅助方法）"""
        if filtered_stages:
            for stage in filtered_stages:
                if str(stage.get('id', '')) == str(stage_id):
                    return stage.get('name', f'阶段 {stage_id}')
        try:
            numeric_id = int(stage_id) if stage_id.replace('.', '').isdigit() else None
            if numeric_id is not None:
                return StageDetector.get_stage_name(numeric_id)
        except:
            pass
        return f'阶段 {stage_id}'
    
    def _get_stage_id_by_index(self, index: int, filtered_stages: Optional[List[Dict]] = None) -> Optional[str]:
        """根据阶段索引获取阶段 ID（辅助方法）"""
        if filtered_stages and 0 <= index < len(filtered_stages):
            return str(filtered_stages[index].get('id', ''))
        return None
    
    def clear(self):
        """清空上下文"""
        self.stage_artifacts.clear()
        self.generated_files.clear()

