如何对需要对外部服务进行 API 调用的 Python 类进行单元测试?

问题描述 投票:0回答:1

该类旨在包装远程托管的大型语言模型,并且必须对服务进行 API 调用以获取结果。这是一个例子。


class ModelWrapper(AbstractLLMInterface):
    """The Claude 3 Sonnet model wrapper following the interface."""

    def __init__(
        self,
        region: str = CONFIGS["GCP"]["REGION"],
        project: str = get_gcp_project_id(),
        model: str = CONFIGS["MODELS"]["SONNET_ID"],
    ) -> None:
        """Set up the claude client using the region and project."""
        self.client: AnthropicVertex = AnthropicVertex(region=region, project_id=project)
        self.model_name: str = model
        self.role_key: str = "role"
        self.content_key: str = "content"
        self.user_key: str = "user"
        llm_logger.debug(msg=f"Initialised sonnet client for {region}, {project} and {self.model_name}.")

    def get_completion(self, user_prompt: str, system_prompt: str, history: List[Correspondence]) -> Iterator[str]:
        """
        Fetch a response from the model.This requires an egress request to GCP and the
        service for Anthropic model must be enabled in the VertexAI console.
        """

        # This is where the API call to GCP service happens
        return self.client.messages.stream(user_prompt, system_prompt, history)

我知道用于 python 单元测试的

MagicMocks
是可以配置为返回我想要的任何内容的对象。但在本例中,构造函数的所有参数都是简单字符串,并且
client
是在类内部构建的。所以模拟客户端没有空间,对吧?

这是否意味着任何单元测试都必须进行API调用?或者类设计不正确?任何帮助将不胜感激。

该类应该是可测试的,而无需调用 API。

python unit-testing mocking automated-tests pytest
1个回答
0
投票

我认为最简单的方法是通过构造函数传递客户端,而不是在内部实例化它:

class ModelWrapper(AbstractLLMInterface):
    def __init__(
        self,
        client: Optional[AnthropicVertex] = None,
        region: str = CONFIGS["GCP"]["REGION"],
        project: str = get_gcp_project_id(),
        model: str = CONFIGS["MODELS"]["SONNET_ID"],
    ) -> None:
        if client is None:
            # Build the real client
            client = AnthropicVertex(region=region, project_id=project)
        self.client = client
        self.model_name: str = model

    def get_completion(self, user_prompt: str, system_prompt: str, history: List[Correspondence]) -> Iterator[str]:
        return self.client.messages.stream(user_prompt, system_prompt, history)

然后你可以像这样测试:

from unittest.mock import MagicMock

def test_model_wrapper():
    mock_client = MagicMock()
    mock_client.messages.stream.return_value = iter(["mocked response"])

    wrapper = ModelWrapper(client=mock_client)
    responses = list(wrapper.get_completion("prompt", "system", []))

    assert responses == ["mocked response"]
© www.soinside.com 2019 - 2024. All rights reserved.