我尝试了多种方法,但都没能得到想要的结果。很多相关问题的答案都是用 @patch 装饰器配合各种不同的模拟方式来解决的,而我一直在用下面这种更简化的写法。我只是不知道如何模拟一个函数,具体来说就是 scan_all。
主类
from databaseClass import scan_all
class JobsDB:
    """Async data-access wrapper that reads job records from a single table."""

    def __init__(self, jobs_table):
        """Keep a reference to the table object that ``scan_all`` is run against."""
        self.jobs_table = jobs_table

    async def get_all_jobs(self, include_executions=False):
        """Scan the whole jobs table and wrap each item in a model object.

        Args:
            include_executions: when truthy, each item becomes a
                ``JobWithExecutions``; otherwise each item becomes a ``Job``.

        Returns:
            A list with one model instance per item returned by ``scan_all``.
        """
        # ``scan_all`` was bound into THIS module by ``from databaseClass import
        # scan_all``, and the lookup happens here at call time. A test must
        # therefore patch it in this module's namespace (e.g.
        # ``patch('db.scan_all')`` if this module is ``db``), not in databaseClass.
        scanned = await scan_all(self.jobs_table)
        model_cls = JobWithExecutions if include_executions else Job
        return [model_cls(**record) for record in scanned]
数据库类
async def scan_all(table, **kwargs):
    """Scan ``table`` to completion, following pagination, and return every item.

    Args:
        table: object with an awaitable ``scan(**kwargs)`` that returns a mapping
            holding an ``'Items'`` list and, while more pages remain, a
            ``'LastEvaluatedKey'``. (Presumably a boto3/aioboto3 DynamoDB Table;
            confirm.)
        **kwargs: extra arguments forwarded to every ``scan`` call. A
            caller-supplied ``ExclusiveStartKey`` applies to the first page only;
            later pages use the key returned by the previous page.

    Returns:
        A new list containing the items of every page, in page order.
    """
    response = await table.scan(**kwargs)
    # Copy so that extending below never mutates the first response's own list.
    items = list(response['Items'])
    while 'LastEvaluatedKey' in response:
        # Build a fresh kwargs dict. Passing ExclusiveStartKey next to **kwargs
        # raised TypeError when the caller had already supplied that key.
        page_kwargs = {**kwargs, 'ExclusiveStartKey': response['LastEvaluatedKey']}
        response = await table.scan(**page_kwargs)
        # extend() avoids re-copying the whole accumulated list on every page.
        items.extend(response['Items'])
    return items
测试类
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from db import JobsDB, Job
@pytest.fixture
def mock_scan_all():
    """Patch ``scan_all`` where ``JobsDB`` looks it up, and yield the mock.

    ``JobsDB.get_all_jobs`` calls the name ``scan_all`` that its own module bound
    via ``from databaseClass import scan_all``. Only patching that binding is
    visible to the code under test. A standalone ``AsyncMock``, or an attribute
    set on the table object, is never called.
    """
    # Assumes JobsDB is defined in ``db`` (as this test module's import of
    # JobsDB suggests); adjust the target if it actually lives elsewhere.
    with patch('db.scan_all', new_callable=AsyncMock) as mock:
        yield mock


@pytest.fixture
def jobs_table():
    """Stand-in for the table object handed to ``JobsDB``."""
    return AsyncMock()


@pytest.fixture
def jobs_db(jobs_table):
    """``JobsDB`` built on the ``jobs_table`` fixture, so tests see the same table."""
    return JobsDB(jobs_table)


@pytest.fixture
def job():
    """A single PENDING ``Job`` used as the expected scan result."""
    return Job(
        id='1839e18a-898b-42d1-8747-9b5495dbb0a6',
        status='PENDING',
        description='Test',
        start_date='2024-01-22T13:00:00.000Z',
        end_date='2024-01-22T14:00:00.000Z',
    )


@pytest.mark.database
@pytest.mark.asyncio
async def test_get_all_jobs(jobs_db, jobs_table, mock_scan_all, job):
    """``get_all_jobs`` builds one ``Job`` per item returned by ``scan_all``."""
    mock_scan_all.return_value = [vars(job)]

    result = await jobs_db.get_all_jobs()

    # The patched function must really be awaited, with the table JobsDB holds.
    mock_scan_all.assert_awaited_once_with(jobs_table)
    assert result == [job]
我尝试过的:
# Didn't work
# jobs_db.jobs_table.scan_all = AsyncMock(return_value=[vars(job)])
# Didn't work
# jobs_db.scan_all.return_value = [vars(job)]
# Didn't work
# scan_all = AsyncMock()
# scan_all.return_value = [vars(job)]
# Didn't work
# mock = AsyncMock()
# mock.scan_all.return_value = [vars(job)]
# Didn't work
# mock_scan_all = AsyncMock()
# mock_scan_all.return_value = [vars(job)]
# jobs_table = MagicMock()
# jobs_table.scan_all = mock_scan_all
# Didn't work
# jobs_table.scan.return_value = {'Items': [vars(job) for job in jobs]}
测试一直失败,错误信息指出右侧列表比左侧多出元素——左侧是空列表:
`assert [] == [job]`
我还尝试断言该模拟函数被调用过一次,但结果总是显示调用次数为 0。
加上下面这段代码之后得到了预期结果,问题随之解决。原因是 scan_all 必须在"使用它的模块"(即执行了 `from databaseClass import scan_all` 的模块)中打补丁,而不是在定义它的模块中打补丁。
@pytest.fixture
def mock_scan_all():
    """Patch ``scan_all`` in the module that calls it, and yield the mock.

    The target must be the module whose namespace ``JobsDB.get_all_jobs``
    resolves ``scan_all`` from (the module that ran
    ``from databaseClass import scan_all``), not ``databaseClass`` itself.
    """
    # NOTE(review): 'module.jobs_db' looks like a placeholder path. The earlier
    # test imports JobsDB from ``db``, which would make the target 'db.scan_all';
    # confirm against the real package layout.
    with patch('module.jobs_db.scan_all', new_callable=AsyncMock) as mock:
        yield mock


# The marker makes pytest-asyncio actually run the coroutine (strict mode),
# matching the earlier version of this test; without it the body may never run.
@pytest.mark.asyncio
async def test_get_all_jobs(jobs_db, jobs_table, mock_scan_all, job):
    """``get_all_jobs`` builds one ``Job`` per item returned by ``scan_all``."""
    mock_scan_all.return_value = [job.dict()]

    result = await jobs_db.get_all_jobs(include_executions=False)

    # Guards against a vacuous pass: the patched function must really be awaited.
    mock_scan_all.assert_awaited_once_with(jobs_db.jobs_table)
    assert result == [job]