Python unittest:当全局变量的值来自函数调用时,如何模拟该全局变量

问题描述 投票:0回答:1

我需要在 Python 中模拟一个全局变量,但该变量的值来自另一个函数的返回值。导入文件时这个函数就会被执行,而我希望得到的是模拟值。

secrets.py

import traceback
import logging
import boto3
import os
import json


# Root logger (no name given); its level threshold is set to INFO.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Both environment variables are read when this module is imported; a missing
# one raises KeyError before any function below can be called.
secret_name = os.environ['SECRETS_NAME']
region_name = os.environ['AWS_REGION']
# Module-level cache for the parsed secret; rebound by init_config_secret().
config_secret = dict()

def init_config_secret():
    """Fetch the secret from AWS Secrets Manager and cache it in ``config_secret``.

    The secret identified by the module-level ``secret_name`` is requested in
    ``region_name``. When the response contains a ``SecretString`` entry, it is
    parsed as JSON and stored in the module-level ``config_secret``.

    Returns:
        The module-level ``config_secret`` on success (unchanged when the
        response has no ``SecretString``), or ``None`` when retrieving or
        parsing the secret raised an exception.
    """
    global config_secret

    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name
    )

    try:
        config_secret_value = client.get_secret_value(SecretId=secret_name)

        if 'SecretString' in config_secret_value:
            config_secret = json.loads(config_secret_value['SecretString'])

        return config_secret
    except Exception:
        # Best-effort by design: the failure is logged, not re-raised.
        # logger.exception records the message together with the traceback,
        # so the stack trace goes through the logging configuration instead
        # of being printed to stderr separately.
        logger.exception('Error while retrieving secrets')
        return None


def get_config_secret():
    """Return the cached secret, loading it first when the cache is empty.

    Returns:
        The module-level ``config_secret`` when it is non-empty; otherwise
        whatever ``init_config_secret()`` returns.
    """
    if config_secret:
        return config_secret
    return init_config_secret()


# When the file is executed directly (not imported), load the secret once.
if __name__ == '__main__':
    get_config_secret()

request_auth.py

import requests
from secrets import get_config_secret
from datetime import datetime, timedelta

# Evaluated once, while this module is being imported (or reloaded). Patching
# get_config_secret afterwards does not change the value already stored here.
config = get_config_secret()

# NOTE(review): this writes the whole secret payload to stdout on every import
# -- confirm it is debug-only.
print(f'config: {config}')

# Cached token dict ("authToken" and "expiry" keys); None until acquire_token()
# has stored one.
token = None

def verifySSL():
    """Return the value passed as ``verify=`` to ``requests.post`` (always True)."""
    verify = True
    return verify


def get_api_root_url():
    """Return the value stored under ``"url"`` in the module-level ``config``.

    Raises:
        KeyError: if ``config`` has no ``"url"`` entry.
    """
    root_url = config["url"]
    return root_url


def build_request_header(access_token, content_type):
    """Build the header dict for a request.

    Args:
        access_token: Value placed under the ``"authorization"`` key.
        content_type: Value placed under the ``"Content-Type"`` key.

    Returns:
        A new dict with exactly those two entries, in that order.
    """
    headers = {"authorization": access_token}
    headers["Content-Type"] = content_type
    return headers


def get_access_token():
    """Return the cached auth token, acquiring a new one when needed.

    A new token is requested through ``acquire_token()`` when none is cached
    yet or when the cached token's ``"expiry"`` lies before the current time.

    Returns:
        The ``"authToken"`` entry of the module-level ``token`` dict.
    """
    # Short-circuit keeps the expiry lookup from running while token is None.
    needs_refresh = token is None or token["expiry"] < datetime.now()
    if needs_refresh:
        acquire_token()

    return token["authToken"]


def acquire_token():
    """Request a new access token and store it in the module-level ``token``.

    A ``client_credentials`` payload built from ``config["id"]`` is POSTed to
    ``config["url"]``. On success ``token`` is rebound to a dict holding the
    ``"authToken"`` taken from the response's ``"access_token"`` field and an
    ``"expiry"`` ten minutes from now.

    If the request fails or the response cannot be parsed, the exception
    propagates and ``token`` keeps its previous value, so a later
    ``get_access_token()`` call retries instead of finding a half-built dict.
    """
    global token

    auth_payload = {
        "grant_type": "client_credentials",
        "client_id": config["id"],
    }
    # NOTE(review): no timeout is passed, so this call can block for as long
    # as the server keeps the connection open -- consider adding one.
    response = requests.post(config["url"], data=auth_payload, verify=verifySSL())

    # Build the replacement first and publish it with a single assignment.
    # Rebinding ``token`` to an empty dict before parsing would leave it
    # without "expiry"/"authToken" whenever parsing raises.
    fresh_token = {
        "authToken": response.json()["access_token"],
        "expiry": datetime.now() + timedelta(minutes=10),
    }
    token = fresh_token

test_request_auth.py

import json, sys
import unittest, os
from unittest.mock import patch, Mock
from unittest import mock


from request_auth import *  


class TestRequestAuth(unittest.TestCase):
    """Tests for ``request_auth.get_api_root_url`` (the asker's attempt)."""

    # Only the get_config_secret patch injects a mock argument; the "config"
    # patch is given an explicit replacement dict and passes nothing in.
    @patch("request_auth.config", {"test": "test"})
    @patch("request_auth.get_config_secret")
    def test_get_api_root_url(self, mock_get_config_secret):
        # request_auth.config was already computed when the module was
        # imported, so this return value is never read by get_api_root_url().
        mock_get_config_secret.return_value = {"url": "test"}
        # NOTE(review): the patched config dict has no "url" key, and the
        # mocked value "test" differs from the expected URL, so this
        # assertion cannot pass as written.
        self.assertEqual(get_api_root_url(), "https://test.test.com/")

我尝试过这种方法,但不知道需要做什么。 有人可以帮忙吗?

python python-unittest python-unittest.mock
1个回答
0
投票

您是在模块导入之后才打补丁的,而全局变量在导入时就已经求值并保存在内存中,因此补丁不会影响它。应用补丁后,您应该使用 importlib 重新加载模块。

import importlib
import unittest
from unittest.mock import patch

from request_auth import *  


class TestRequestAuth(unittest.TestCase):
    """Tests for ``request_auth.get_api_root_url`` with a mocked secret."""

    # Patch the function where it is defined. Reloading request_auth
    # re-executes ``from secrets import get_config_secret``, which would
    # overwrite a patch placed on ``request_auth.get_config_secret``.
    @patch("secrets.get_config_secret")
    def test_get_api_root_url(self, mock_get_config_secret):
        """The reloaded module must read its URL from the mocked secret."""
        # ``from request_auth import *`` does not bind the module name itself;
        # this local import fetches the already-loaded module object.
        import request_auth

        mock_get_config_secret.return_value = {"url": "https://test.test.com/"}

        # Re-run the module body so that ``config`` is recomputed while the
        # patch is active.
        importlib.reload(request_auth)

        self.assertEqual(
            request_auth.get_api_root_url(), "https://test.test.com/"
        )
© www.soinside.com 2019 - 2024. All rights reserved.