该类旨在包装远程托管的大型语言模型,并且必须对服务进行 API 调用以获取结果。这是一个例子。
class ModelWrapper(AbstractLLMInterface):
"""The Claude 3 Sonnet model wrapper following the interface."""
def __init__(
self,
region: str = CONFIGS["GCP"]["REGION"],
project: str = get_gcp_project_id(),
model: str = CONFIGS["MODELS"]["SONNET_ID"],
) -> None:
"""Set up the claude client using the region and project."""
self.client: AnthropicVertex = AnthropicVertex(region=region, project_id=project)
self.model_name: str = model
self.role_key: str = "role"
self.content_key: str = "content"
self.user_key: str = "user"
llm_logger.debug(msg=f"Initialised sonnet client for {region}, {project} and {self.model_name}.")
def get_completion(self, user_prompt: str, system_prompt: str, history: List[Correspondence]) -> Iterator[str]:
"""
Fetch a response from the model.This requires an egress request to GCP and the
service for Anthropic model must be enabled in the VertexAI console.
"""
# This is where the API call to GCP service happens
return self.client.messages.stream(user_prompt, system_prompt, history)
我知道用于 python 单元测试的
MagicMocks
是可以配置为返回我想要的任何内容的对象。但在本例中,构造函数的所有参数都是简单字符串,并且 client
是在类内部构建的。所以模拟客户端没有空间,对吧?
这是否意味着任何单元测试都必须进行API调用?或者类设计不正确?任何帮助将不胜感激。
该类应该是可测试的,而无需调用 API。
我认为最简单的方法是通过构造函数传递客户端,而不是在内部实例化它:
class ModelWrapper(AbstractLLMInterface):
def __init__(
self,
client: Optional[AnthropicVertex] = None,
region: str = CONFIGS["GCP"]["REGION"],
project: str = get_gcp_project_id(),
model: str = CONFIGS["MODELS"]["SONNET_ID"],
) -> None:
if client is None:
# Build the real client
client = AnthropicVertex(region=region, project_id=project)
self.client = client
self.model_name: str = model
def get_completion(self, user_prompt: str, system_prompt: str, history: List[Correspondence]) -> Iterator[str]:
return self.client.messages.stream(user_prompt, system_prompt, history)
然后你可以像这样测试:
from unittest.mock import MagicMock
def test_model_wrapper():
mock_client = MagicMock()
mock_client.messages.stream.return_value = iter(["mocked response"])
wrapper = ModelWrapper(client=mock_client)
responses = list(wrapper.get_completion("prompt", "system", []))
assert responses == ["mocked response"]