"""AI服务主类 - 整合所有Agent和服务组件""" import json import uuid from datetime import datetime from typing import Dict, Any, List from loguru import logger from openai import OpenAI import httpx from ..models.schemas import ( ChatRequest, ChatResponse, QuestionType, AgentContext, RAGResult, ToolCallResult ) from ..agents.orchestrator import QuestionOrchestratorAgent from ..agents.expert_agents import ExpertAgentFactory from ..agents.summary_agent import SummaryAgent from ..services.rag_service import RAGService from ..services.tool_service import ToolService class AIService: """AI服务主类,协调所有AI组件""" def __init__( self, openai_api_key: str, database_config: Dict[str, Any] = None, api_config: Dict[str, Any] = None, knowledge_base_path: str = "knowledge_base" ): # 初始化OpenAI客户端 self.openai_client = OpenAI( api_key=openai_api_key, http_client=httpx.Client( proxy="socks5://socksuser:8uhb9ijn@35.236.151.13:1080" ) ) # 初始化各个Agent self.orchestrator_agent = QuestionOrchestratorAgent(self.openai_client) self.summary_agent = SummaryAgent(self.openai_client) # 初始化服务 self.rag_service = RAGService(openai_api_key, knowledge_base_path) self.tool_service = ToolService(database_config, api_config) # 初始化会话存储(实际项目中应该使用Redis等缓存) self.sessions: Dict[str, Dict[str, Any]] = {} logger.info("AI服务初始化完成") async def initialize(self): """初始化AI服务""" logger.info("开始初始化AI服务...") # 初始化知识库 await self.rag_service.initialize_default_knowledge() logger.info("AI服务初始化完成") async def chat(self, request: ChatRequest) -> ChatResponse: """ 处理聊天请求 Args: request: 聊天请求 Returns: 聊天响应 """ try: # 生成会话ID session_id = request.session_id or str(uuid.uuid4()) logger.info(f"处理聊天请求,会话ID: {session_id}") # 步骤1: 问题分类 classification_result = await self.orchestrator_agent.classify_question(request.message) question_type = classification_result["question_type"] logger.info(f"问题分类结果: {question_type}") # 步骤2: 创建专家Agent expert_agent = ExpertAgentFactory.create_agent(question_type, self.openai_client) # 步骤3: 判断是否需要RAG need_rag = await expert_agent.need_rag(request.message) rag_results = [] if need_rag: logger.info("执行RAG查询") rag_results = await self._execute_rag(request.message, question_type) # 创建Agent上下文 context = AgentContext( user_question=request.message, question_type=question_type, rag_results=rag_results, tool_call_results=[], session_id=session_id, metadata=request.context or {} ) # 步骤4: 判断是否需要工具调用 need_tool_call = await expert_agent.need_tool_call(request.message, context) if need_tool_call: logger.info("执行智能工具调用") tool_results = await self._execute_intelligent_tool_calls(request.message, question_type, context) context.tool_call_results = tool_results # 步骤5: 生成专家回复 expert_response = await expert_agent.generate_response(context) # 步骤6: 美化回复 final_response = await self.summary_agent.summarize_and_beautify( request.message, question_type, expert_response ) # 更新会话 self._update_session(session_id, request.message, final_response, question_type) # 构造响应 response = ChatResponse( message=final_response, question_type=question_type, session_id=session_id, need_rag=need_rag, need_tool_call=need_tool_call, metadata={ "classification_confidence": classification_result.get("confidence", 0.0), "classification_reasoning": classification_result.get("reasoning", ""), "rag_results_count": len(rag_results), "tool_calls_count": len(context.tool_call_results) } ) logger.info(f"聊天请求处理完成,会话ID: {session_id}") return response except Exception as e: logger.error(f"处理聊天请求失败: {e}") # 返回错误响应 return ChatResponse( message="抱歉,我暂时无法处理您的请求,请稍后再试。", question_type=QuestionType.CHAT, session_id=session_id or str(uuid.uuid4()), need_rag=False, need_tool_call=False, metadata={"error": str(e)} ) async def _execute_rag(self, user_question: str, question_type: QuestionType) -> List[RAGResult]: """执行RAG查询""" collection_mapping = { QuestionType.PAGE_NAVIGATION: "page_navigation", QuestionType.SYSTEM_GUIDE: "system_guide", QuestionType.PRODUCTION_QA: "production_data", QuestionType.CHAT: "system_guide" # 闲聊也可能需要查询系统信息 } collection_name = collection_mapping.get(question_type, "system_guide") return await self.rag_service.query(user_question, collection_name) async def _execute_intelligent_tool_calls( self, user_question: str, question_type: QuestionType, context: AgentContext ) -> List[ToolCallResult]: """使用大模型智能决策执行工具调用 - 参考report_generation.py的直接function calling模式""" try: # 使用直接function calling模式 tool_results = await self._execute_multi_turn_function_calls( user_question, question_type, context.rag_results, { "session_id": context.session_id, "question_type": question_type.value, "metadata": context.metadata } ) return tool_results except Exception as e: logger.error(f"多轮function calling失败: {e}") # 回退到简单工具调用 return await self._fallback_tool_calls(user_question, question_type, context) async def _execute_multi_turn_function_calls( self, user_question: str, question_type: QuestionType, rag_results: List[RAGResult] = None, context: Dict[str, Any] = None ) -> List[ToolCallResult]: """ 执行多轮function calls - 完全参考report_generation.py的实现模式 """ if not self.openai_client: logger.error("OpenAI客户端未配置") return [] # 构建增强的提示词,包含多轮调用指令 domain_map = { QuestionType.PAGE_NAVIGATION: "页面导航", QuestionType.SYSTEM_GUIDE: "系统使用指导", QuestionType.PRODUCTION_QA: "生产数据分析", QuestionType.CHAT: "智能助手" } domain = domain_map.get(question_type, "MES系统") enhanced_prompt = f"""用户请求: {user_question} 作为{domain}专家,请分析这个请求并使用合适的工具来完成任务。 你可以进行多轮操作实现你的目标,如执行完操作后还有后续操作,请回复:'尚未完成',如执行完成,请回复'已完成'。 可用工具说明: - page_navigation: 页面跳转导航 - database_query: 查询生产数据、设备状态等 - report_generation: 生成各类报表 - data_analysis: 数据分析和趋势预测 - document_generation: 生成业务文档 - external_api_call: 调用外部API - rag_search: 搜索知识库 - workflow_execution: 执行工作流程 请根据用户需求制定执行计划,选择合适的工具组合,按步骤完成任务。""" # 添加上下文信息 if context: context_str = f"\n\n上下文信息:{json.dumps(context, ensure_ascii=False, indent=2)}" enhanced_prompt += context_str # 添加RAG信息 if rag_results: rag_str = "\n\n相关知识库信息:\n" + "\n".join([f"- {r.content}" for r in rag_results[:3]]) enhanced_prompt += rag_str messages = [ {"role": "user", "content": enhanced_prompt} ] # 获取工具schemas functions = self.tool_service.get_tool_schemas() tool_results = [] try: # 首次调用 response = self.openai_client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, functions=functions, function_call="auto", temperature=0.1 ) message = response.choices[0].message logger.info(f"首次响应: {message.content}") messages.append({ "role": "assistant", "content": message.content, "function_call": message.function_call }) # 多轮调用循环 - 完全参考report_generation.py max_turns = 10 turn = 0 while turn < max_turns: turn += 1 # 检查是否完成 if message.content and "已完成" in message.content: logger.info("任务完成") break # 执行工具调用 if message.function_call: function_name = message.function_call.name try: function_args = json.loads(message.function_call.arguments) except json.JSONDecodeError as e: logger.error(f"解析函数参数失败: {e}") function_args = {} # 执行工具 tool_result = await self.tool_service.execute_tool(function_name, function_args) tool_results.append(tool_result) # 准备结果内容 if tool_result.success: result_content = json.dumps(tool_result.result, ensure_ascii=False) if tool_result.result else "执行成功" else: result_content = tool_result.error_message or "执行失败" # 添加工具执行结果到消息历史 messages.append({ "role": "function", "name": function_name, "content": result_content }) logger.info(f"第{turn}轮 - 工具 {function_name} 执行{'成功' if tool_result.success else '失败'}") # 继续下一轮调用 response = self.openai_client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, functions=functions, function_call="auto", temperature=0.1 ) message = response.choices[0].message logger.info(f"第{turn}轮响应: {message.content}") messages.append({ "role": "assistant", "content": message.content, "function_call": message.function_call }) else: # 没有function call但也没完成,询问并尝试继续 logger.warning(f"第{turn}轮 - 没有工具调用但任务未完成") messages.append({ "role": "user", "content": "你回复了尚未完成,但并没有返回function call,是遇到什么问题了吗?如果需要继续执行,请继续回复:尚未完成" }) response = self.openai_client.chat.completions.create( model="gpt-3.5-turbo", messages=messages, functions=functions, function_call="auto", temperature=0.1 ) message = response.choices[0].message messages.append({ "role": "assistant", "content": message.content, "function_call": message.function_call }) if turn >= max_turns: logger.warning(f"达到最大轮次限制({max_turns}),停止执行") except Exception as e: logger.error(f"多轮function calling执行失败: {e}") return [] return tool_results async def _fallback_tool_calls( self, user_question: str, question_type: QuestionType, context: AgentContext ) -> List[ToolCallResult]: """回退的简单工具调用逻辑""" tool_calls = [] if question_type == QuestionType.PAGE_NAVIGATION: # 页面跳转需要调用导航工具 for rag_result in context.rag_results: if "page" in rag_result.source: # 从RAG结果中提取页面路径 page_path = self._extract_page_path(rag_result.content) if page_path: tool_calls.append({ "tool_name": "page_navigation", "params": { "page_path": page_path, "operation": user_question } }) break elif question_type == QuestionType.PRODUCTION_QA: # 生产QA需要查询数据库 query_type = self._determine_query_type(user_question) tool_calls.append({ "tool_name": "database_query", "params": { "query_type": query_type, "filters": {} } }) # 执行工具调用 if tool_calls: return await self.tool_service.batch_execute_tools(tool_calls) return [] def _extract_page_path(self, content: str) -> str: """从内容中提取页面路径""" # 简单的路径提取逻辑,实际项目中可以更复杂 if "/material/inbound" in content: return "/material/inbound" elif "/material/outbound" in content: return "/material/outbound" elif "/production/plan" in content: return "/production/plan" return "" def _determine_query_type(self, user_question: str) -> str: """根据用户问题确定查询类型""" question_lower = user_question.lower() if any(keyword in question_lower for keyword in ["生产", "订单", "效率", "产量"]): return "production_status" elif any(keyword in question_lower for keyword in ["物料", "库存", "入库", "出库"]): return "material_inventory" elif any(keyword in question_lower for keyword in ["设备", "机器", "故障", "维护"]): return "equipment_status" return "production_status" def _update_session(self, session_id: str, user_message: str, ai_response: str, question_type: QuestionType): """更新会话信息""" current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") if session_id not in self.sessions: self.sessions[session_id] = { "messages": [], "created_at": current_time, "last_activity": current_time } self.sessions[session_id]["messages"].extend([ {"role": "user", "content": user_message, "question_type": question_type.value}, {"role": "assistant", "content": ai_response, "question_type": question_type.value} ]) self.sessions[session_id]["last_activity"] = current_time def get_session_history(self, session_id: str) -> Dict[str, Any]: """获取会话历史""" return self.sessions.get(session_id, {"messages": []})