ai_service.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. """AI服务主类 - 整合所有Agent和服务组件"""
  2. import json
  3. import uuid
  4. from datetime import datetime
  5. from typing import Dict, Any, List
  6. from loguru import logger
  7. from openai import OpenAI
  8. import httpx
  9. from ..models.schemas import (
  10. ChatRequest, ChatResponse, QuestionType,
  11. AgentContext, RAGResult, ToolCallResult
  12. )
  13. from ..agents.orchestrator import QuestionOrchestratorAgent
  14. from ..agents.expert_agents import ExpertAgentFactory
  15. from ..agents.summary_agent import SummaryAgent
  16. from ..services.rag_service import RAGService
  17. from ..services.tool_service import ToolService
  18. class AIService:
  19. """AI服务主类,协调所有AI组件"""
  20. def __init__(
  21. self,
  22. openai_api_key: str,
  23. database_config: Dict[str, Any] = None,
  24. api_config: Dict[str, Any] = None,
  25. knowledge_base_path: str = "knowledge_base"
  26. ):
  27. # 初始化OpenAI客户端
  28. self.openai_client = OpenAI(
  29. api_key=openai_api_key,
  30. http_client=httpx.Client(
  31. proxy="socks5://socksuser:8uhb9ijn@35.236.151.13:1080"
  32. )
  33. )
  34. # 初始化各个Agent
  35. self.orchestrator_agent = QuestionOrchestratorAgent(self.openai_client)
  36. self.summary_agent = SummaryAgent(self.openai_client)
  37. # 初始化服务
  38. self.rag_service = RAGService(openai_api_key, knowledge_base_path)
  39. self.tool_service = ToolService(database_config, api_config)
  40. # 初始化会话存储(实际项目中应该使用Redis等缓存)
  41. self.sessions: Dict[str, Dict[str, Any]] = {}
  42. logger.info("AI服务初始化完成")
  43. async def initialize(self):
  44. """初始化AI服务"""
  45. logger.info("开始初始化AI服务...")
  46. # 初始化知识库
  47. await self.rag_service.initialize_default_knowledge()
  48. logger.info("AI服务初始化完成")
  49. async def chat(self, request: ChatRequest) -> ChatResponse:
  50. """
  51. 处理聊天请求
  52. Args:
  53. request: 聊天请求
  54. Returns:
  55. 聊天响应
  56. """
  57. try:
  58. # 生成会话ID
  59. session_id = request.session_id or str(uuid.uuid4())
  60. logger.info(f"处理聊天请求,会话ID: {session_id}")
  61. # 步骤1: 问题分类
  62. classification_result = await self.orchestrator_agent.classify_question(request.message)
  63. question_type = classification_result["question_type"]
  64. logger.info(f"问题分类结果: {question_type}")
  65. # 步骤2: 创建专家Agent
  66. expert_agent = ExpertAgentFactory.create_agent(question_type, self.openai_client)
  67. # 步骤3: 判断是否需要RAG
  68. need_rag = await expert_agent.need_rag(request.message)
  69. rag_results = []
  70. if need_rag:
  71. logger.info("执行RAG查询")
  72. rag_results = await self._execute_rag(request.message, question_type)
  73. # 创建Agent上下文
  74. context = AgentContext(
  75. user_question=request.message,
  76. question_type=question_type,
  77. rag_results=rag_results,
  78. tool_call_results=[],
  79. session_id=session_id,
  80. metadata=request.context or {}
  81. )
  82. # 步骤4: 判断是否需要工具调用
  83. need_tool_call = await expert_agent.need_tool_call(request.message, context)
  84. if need_tool_call:
  85. logger.info("执行智能工具调用")
  86. tool_results = await self._execute_intelligent_tool_calls(request.message, question_type, context)
  87. context.tool_call_results = tool_results
  88. # 步骤5: 生成专家回复
  89. expert_response = await expert_agent.generate_response(context)
  90. # 步骤6: 美化回复
  91. final_response = await self.summary_agent.summarize_and_beautify(
  92. request.message, question_type, expert_response
  93. )
  94. # 更新会话
  95. self._update_session(session_id, request.message, final_response, question_type)
  96. # 构造响应
  97. response = ChatResponse(
  98. message=final_response,
  99. question_type=question_type,
  100. session_id=session_id,
  101. need_rag=need_rag,
  102. need_tool_call=need_tool_call,
  103. metadata={
  104. "classification_confidence": classification_result.get("confidence", 0.0),
  105. "classification_reasoning": classification_result.get("reasoning", ""),
  106. "rag_results_count": len(rag_results),
  107. "tool_calls_count": len(context.tool_call_results)
  108. }
  109. )
  110. logger.info(f"聊天请求处理完成,会话ID: {session_id}")
  111. return response
  112. except Exception as e:
  113. logger.error(f"处理聊天请求失败: {e}")
  114. # 返回错误响应
  115. return ChatResponse(
  116. message="抱歉,我暂时无法处理您的请求,请稍后再试。",
  117. question_type=QuestionType.CHAT,
  118. session_id=session_id or str(uuid.uuid4()),
  119. need_rag=False,
  120. need_tool_call=False,
  121. metadata={"error": str(e)}
  122. )
  123. async def _execute_rag(self, user_question: str, question_type: QuestionType) -> List[RAGResult]:
  124. """执行RAG查询"""
  125. collection_mapping = {
  126. QuestionType.PAGE_NAVIGATION: "page_navigation",
  127. QuestionType.SYSTEM_GUIDE: "system_guide",
  128. QuestionType.PRODUCTION_QA: "production_data",
  129. QuestionType.CHAT: "system_guide" # 闲聊也可能需要查询系统信息
  130. }
  131. collection_name = collection_mapping.get(question_type, "system_guide")
  132. return await self.rag_service.query(user_question, collection_name)
  133. async def _execute_intelligent_tool_calls(
  134. self,
  135. user_question: str,
  136. question_type: QuestionType,
  137. context: AgentContext
  138. ) -> List[ToolCallResult]:
  139. """使用大模型智能决策执行工具调用 - 参考report_generation.py的直接function calling模式"""
  140. try:
  141. # 使用直接function calling模式
  142. tool_results = await self._execute_multi_turn_function_calls(
  143. user_question,
  144. question_type,
  145. context.rag_results,
  146. {
  147. "session_id": context.session_id,
  148. "question_type": question_type.value,
  149. "metadata": context.metadata
  150. }
  151. )
  152. return tool_results
  153. except Exception as e:
  154. logger.error(f"多轮function calling失败: {e}")
  155. # 回退到简单工具调用
  156. return await self._fallback_tool_calls(user_question, question_type, context)
  157. async def _execute_multi_turn_function_calls(
  158. self,
  159. user_question: str,
  160. question_type: QuestionType,
  161. rag_results: List[RAGResult] = None,
  162. context: Dict[str, Any] = None
  163. ) -> List[ToolCallResult]:
  164. """
  165. 执行多轮function calls - 完全参考report_generation.py的实现模式
  166. """
  167. if not self.openai_client:
  168. logger.error("OpenAI客户端未配置")
  169. return []
  170. # 构建增强的提示词,包含多轮调用指令
  171. domain_map = {
  172. QuestionType.PAGE_NAVIGATION: "页面导航",
  173. QuestionType.SYSTEM_GUIDE: "系统使用指导",
  174. QuestionType.PRODUCTION_QA: "生产数据分析",
  175. QuestionType.CHAT: "智能助手"
  176. }
  177. domain = domain_map.get(question_type, "MES系统")
  178. enhanced_prompt = f"""用户请求: {user_question}
  179. 作为{domain}专家,请分析这个请求并使用合适的工具来完成任务。
  180. 你可以进行多轮操作实现你的目标,如执行完操作后还有后续操作,请回复:'尚未完成',如执行完成,请回复'已完成'。
  181. 可用工具说明:
  182. - page_navigation: 页面跳转导航
  183. - database_query: 查询生产数据、设备状态等
  184. - report_generation: 生成各类报表
  185. - data_analysis: 数据分析和趋势预测
  186. - document_generation: 生成业务文档
  187. - external_api_call: 调用外部API
  188. - rag_search: 搜索知识库
  189. - workflow_execution: 执行工作流程
  190. 请根据用户需求制定执行计划,选择合适的工具组合,按步骤完成任务。"""
  191. # 添加上下文信息
  192. if context:
  193. context_str = f"\n\n上下文信息:{json.dumps(context, ensure_ascii=False, indent=2)}"
  194. enhanced_prompt += context_str
  195. # 添加RAG信息
  196. if rag_results:
  197. rag_str = "\n\n相关知识库信息:\n" + "\n".join([f"- {r.content}" for r in rag_results[:3]])
  198. enhanced_prompt += rag_str
  199. messages = [
  200. {"role": "user", "content": enhanced_prompt}
  201. ]
  202. # 获取工具schemas
  203. functions = self.tool_service.get_tool_schemas()
  204. tool_results = []
  205. try:
  206. # 首次调用
  207. response = self.openai_client.chat.completions.create(
  208. model="gpt-3.5-turbo",
  209. messages=messages,
  210. functions=functions,
  211. function_call="auto",
  212. temperature=0.1
  213. )
  214. message = response.choices[0].message
  215. logger.info(f"首次响应: {message.content}")
  216. messages.append({
  217. "role": "assistant",
  218. "content": message.content,
  219. "function_call": message.function_call
  220. })
  221. # 多轮调用循环 - 完全参考report_generation.py
  222. max_turns = 10
  223. turn = 0
  224. while turn < max_turns:
  225. turn += 1
  226. # 检查是否完成
  227. if message.content and "已完成" in message.content:
  228. logger.info("任务完成")
  229. break
  230. # 执行工具调用
  231. if message.function_call:
  232. function_name = message.function_call.name
  233. try:
  234. function_args = json.loads(message.function_call.arguments)
  235. except json.JSONDecodeError as e:
  236. logger.error(f"解析函数参数失败: {e}")
  237. function_args = {}
  238. # 执行工具
  239. tool_result = await self.tool_service.execute_tool(function_name, function_args)
  240. tool_results.append(tool_result)
  241. # 准备结果内容
  242. if tool_result.success:
  243. result_content = json.dumps(tool_result.result, ensure_ascii=False) if tool_result.result else "执行成功"
  244. else:
  245. result_content = tool_result.error_message or "执行失败"
  246. # 添加工具执行结果到消息历史
  247. messages.append({
  248. "role": "function",
  249. "name": function_name,
  250. "content": result_content
  251. })
  252. logger.info(f"第{turn}轮 - 工具 {function_name} 执行{'成功' if tool_result.success else '失败'}")
  253. # 继续下一轮调用
  254. response = self.openai_client.chat.completions.create(
  255. model="gpt-3.5-turbo",
  256. messages=messages,
  257. functions=functions,
  258. function_call="auto",
  259. temperature=0.1
  260. )
  261. message = response.choices[0].message
  262. logger.info(f"第{turn}轮响应: {message.content}")
  263. messages.append({
  264. "role": "assistant",
  265. "content": message.content,
  266. "function_call": message.function_call
  267. })
  268. else:
  269. # 没有function call但也没完成,询问并尝试继续
  270. logger.warning(f"第{turn}轮 - 没有工具调用但任务未完成")
  271. messages.append({
  272. "role": "user",
  273. "content": "你回复了尚未完成,但并没有返回function call,是遇到什么问题了吗?如果需要继续执行,请继续回复:尚未完成"
  274. })
  275. response = self.openai_client.chat.completions.create(
  276. model="gpt-3.5-turbo",
  277. messages=messages,
  278. functions=functions,
  279. function_call="auto",
  280. temperature=0.1
  281. )
  282. message = response.choices[0].message
  283. messages.append({
  284. "role": "assistant",
  285. "content": message.content,
  286. "function_call": message.function_call
  287. })
  288. if turn >= max_turns:
  289. logger.warning(f"达到最大轮次限制({max_turns}),停止执行")
  290. except Exception as e:
  291. logger.error(f"多轮function calling执行失败: {e}")
  292. return []
  293. return tool_results
  294. async def _fallback_tool_calls(
  295. self,
  296. user_question: str,
  297. question_type: QuestionType,
  298. context: AgentContext
  299. ) -> List[ToolCallResult]:
  300. """回退的简单工具调用逻辑"""
  301. tool_calls = []
  302. if question_type == QuestionType.PAGE_NAVIGATION:
  303. # 页面跳转需要调用导航工具
  304. for rag_result in context.rag_results:
  305. if "page" in rag_result.source:
  306. # 从RAG结果中提取页面路径
  307. page_path = self._extract_page_path(rag_result.content)
  308. if page_path:
  309. tool_calls.append({
  310. "tool_name": "page_navigation",
  311. "params": {
  312. "page_path": page_path,
  313. "operation": user_question
  314. }
  315. })
  316. break
  317. elif question_type == QuestionType.PRODUCTION_QA:
  318. # 生产QA需要查询数据库
  319. query_type = self._determine_query_type(user_question)
  320. tool_calls.append({
  321. "tool_name": "database_query",
  322. "params": {
  323. "query_type": query_type,
  324. "filters": {}
  325. }
  326. })
  327. # 执行工具调用
  328. if tool_calls:
  329. return await self.tool_service.batch_execute_tools(tool_calls)
  330. return []
  331. def _extract_page_path(self, content: str) -> str:
  332. """从内容中提取页面路径"""
  333. # 简单的路径提取逻辑,实际项目中可以更复杂
  334. if "/material/inbound" in content:
  335. return "/material/inbound"
  336. elif "/material/outbound" in content:
  337. return "/material/outbound"
  338. elif "/production/plan" in content:
  339. return "/production/plan"
  340. return ""
  341. def _determine_query_type(self, user_question: str) -> str:
  342. """根据用户问题确定查询类型"""
  343. question_lower = user_question.lower()
  344. if any(keyword in question_lower for keyword in ["生产", "订单", "效率", "产量"]):
  345. return "production_status"
  346. elif any(keyword in question_lower for keyword in ["物料", "库存", "入库", "出库"]):
  347. return "material_inventory"
  348. elif any(keyword in question_lower for keyword in ["设备", "机器", "故障", "维护"]):
  349. return "equipment_status"
  350. return "production_status"
  351. def _update_session(self, session_id: str, user_message: str, ai_response: str, question_type: QuestionType):
  352. """更新会话信息"""
  353. current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
  354. if session_id not in self.sessions:
  355. self.sessions[session_id] = {
  356. "messages": [],
  357. "created_at": current_time,
  358. "last_activity": current_time
  359. }
  360. self.sessions[session_id]["messages"].extend([
  361. {"role": "user", "content": user_message, "question_type": question_type.value},
  362. {"role": "assistant", "content": ai_response, "question_type": question_type.value}
  363. ])
  364. self.sessions[session_id]["last_activity"] = current_time
  365. def get_session_history(self, session_id: str) -> Dict[str, Any]:
  366. """获取会话历史"""
  367. return self.sessions.get(session_id, {"messages": []})