12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- """基本功能测试"""
- import pytest
- import asyncio
- from unittest.mock import Mock, patch
- from src.models.schemas import ChatRequest, QuestionType
- from src.agents.orchestrator import QuestionOrchestratorAgent
- from src.agents.expert_agents import ExpertAgentFactory
class TestBasicFunctionality:
    """Basic checks for the question-type enum, request model and agent factory."""

    def test_question_types_enum(self):
        """Each checked QuestionType member compares equal to its string value."""
        expected_pairs = (
            (QuestionType.PAGE_NAVIGATION, "page_navigation"),
            (QuestionType.SYSTEM_GUIDE, "system_guide"),
            (QuestionType.PRODUCTION_QA, "production_qa"),
            (QuestionType.CHAT, "chat"),
        )
        for member, raw_value in expected_pairs:
            assert member == raw_value

    def test_chat_request_model(self):
        """A ChatRequest built from a message alone has no session ID and an empty context."""
        text = "测试消息"
        req = ChatRequest(message=text)

        assert req.message == text
        assert req.session_id is None
        assert req.context == {}

    def test_expert_agent_factory(self):
        """The factory returns an agent with the expected attributes for every question type."""
        client = Mock()
        required_attrs = ('need_rag', 'need_tool_call', 'generate_response')

        # Every question type must yield a non-None agent exposing these attributes.
        for q_type in QuestionType:
            agent = ExpertAgentFactory.create_agent(q_type, client)
            assert agent is not None
            for attr_name in required_attrs:
                assert hasattr(agent, attr_name)
class TestAPIModels:
    """Tests for the API request model."""

    def test_chat_request_validation(self):
        """ChatRequest keeps the message, session ID and context it was built with."""
        greeting = "你好"

        # Message only.
        plain = ChatRequest(message=greeting)
        assert plain.message == "你好"

        # Message plus a session ID.
        with_session = ChatRequest(message=greeting, session_id="test-session")
        assert with_session.session_id == "test-session"

        # Message plus a context dict.
        with_context = ChatRequest(message=greeting, context={"user_id": "123"})
        assert with_context.context["user_id"] == "123"
@pytest.mark.asyncio
async def test_async_functionality():
    """Placeholder for asynchronous tests; it currently performs no checks."""
    # Async test cases can be added here.
if __name__ == "__main__":
    # Run the synchronous tests directly, without going through the pytest runner.
    # test_expert_agent_factory is called as well so that the success message
    # below covers every test defined in TestBasicFunctionality.
    test_basic = TestBasicFunctionality()
    test_basic.test_question_types_enum()
    test_basic.test_chat_request_model()
    test_basic.test_expert_agent_factory()

    test_api = TestAPIModels()
    test_api.test_chat_request_validation()

    # Only reached when none of the assertions above failed.
    print("✅ 所有基本测试通过")
|