I am trying to write an integration test for the following Python code:
```python
import xx.settings.config as stg
from xx.infrastructure.utils import csvReader, dataframeWriter
from pyspark.sql import SparkSession
from typing import List
from awsglue.utils import getResolvedOptions
import sys


def main(argv: List[str]) -> None:
    args = getResolvedOptions(
        argv,
        ['JOB_NAME', 'S3_BRONZE_BUCKET_NAME', 'S3_PRE_SILVER_BUCKET_NAME', 'S3_BRONZE_PATH', 'S3_PRE_SILVER_PATH'],
    )
    s3_bronze_bucket_name = args['S3_BRONZE_BUCKET_NAME']
    s3_pre_silver_bucket_name = args['S3_PRE_SILVER_BUCKET_NAME']
    s3_bronze_path = args['S3_BRONZE_PATH']
    s3_pre_silver_path = args['S3_PRE_SILVER_PATH']

    spark = SparkSession.builder.getOrCreate()
    spark.conf.set('spark.sql.sources.partitionOverwriteMode', 'dynamic')

    for table in list(stg.data_schema.keys()):
        schema = stg.data_schema[table].columns.to_dict()
        df = csvReader(spark, s3_bronze_bucket_name, s3_bronze_path, table, schema, '\t')
        dataframeWriter(df, s3_pre_silver_bucket_name, s3_pre_silver_path, table, stg.data_schema[table].partitionKey)


if __name__ == '__main__':
    main(sys.argv)
```
I essentially loop over a list of tables, read their contents (CSV format) from S3, and write them back to S3 as Parquet.

These are the definitions of csvReader and dataframeWriter:
```python
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType


def csvReader(spark: SparkSession, bucket: str, path: str, table: str, schema: StructType, sep: str) -> DataFrame:
    return (
        spark.read.format('csv')
        .option('header', 'true')
        .option('sep', sep)
        .schema(schema)
        .load(f's3a://{bucket}/{path}/{table}.csv')
    )


def dataframeWriter(df: DataFrame, bucket: str, path: str, table: str, partition_key: str) -> None:
    df.write.partitionBy(partition_key).mode('overwrite').parquet(f's3a://{bucket}/{path}/{table}/')
```
For my integration test, I want to replace the S3 interactions with local file interactions (read the CSVs locally and write the Parquet files locally). This is what I did:
```python
import os
from unittest import TestCase
from unittest.mock import patch, Mock

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType

import xx.settings.config as stg
import xx.application.perfmarket_pre_silver as perfmarket_pre_silver
from dvr_config_utils.config import initialize_settings


def local_csvReader(spark: SparkSession, table: str, schema: StructType, sep: str):
    """Mocked function that replaces the real csvReader; reads locally rather than from S3."""
    return (
        spark.read.format('csv')
        .option('header', 'true')
        .option('sep', sep)
        .schema(schema)
        .load(f'../input_mock/{table}.csv')
    )


def local_dataframeWriter(df, table: str, partition_key: str):
    """Mocked function that replaces the real dataframeWriter; writes locally rather than to S3."""
    output_dir = f'../output_mock/{table}/'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    df.write.partitionBy(partition_key).mode('overwrite').parquet(output_dir)


class TestPerfmarketSilver(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.master('local').appName('TestPerfmarketSilver').getOrCreate()
        cls.spark.conf.set('spark.sql.sources.partitionOverwriteMode', 'dynamic')

    @classmethod
    def tearDownClass(cls):
        """Clean up the Spark session and test data."""
        cls.spark.stop()
        os.system('rm -rf ../output_mock')

    @patch('xx.application.cc.getResolvedOptions')
    @patch('src.xx.infrastructure.utils.csvReader', side_effect=local_csvReader)
    @patch('xx.infrastructure.utils.dataframeWriter', side_effect=local_dataframeWriter)
    def test_main(self, mock_csvreader, mock_datawriter, mocked_get_resolved_options: Mock):
        expected_results = {'chemins': {'nbRows': 8}}
        mocked_get_resolved_options.return_value = {
            'JOB_NAME': 'perfmarket_pre_silver_test',
            'S3_BRONZE_BUCKET_NAME': 'test_bronze',
            'S3_PRE_SILVER_BUCKET_NAME': 'test_pre_silver',
            'S3_BRONZE_PATH': '../input_mock',
            'S3_PRE_SILVER_PATH': '../output_mock'
        }

        perfmarket_pre_silver.main([])

        for table in stg.data_schema.keys():
            # Verify that the output Parquet file is created
            output_path = f'../output_mock/{table}/'
            self.assertTrue(os.path.exists(output_path))

            # Read the written Parquet file and check the data
            written_df = self.spark.read.parquet(output_path)
            self.assertEqual(written_df.count(), expected_results[table]['nbRows'])  # Check row count
            self.assertTrue(
                [
                    column_data['bronze_name']
                    for table in stg.data_schema.values()
                    for column_data in table['columns'].values()
                ]
                == written_df.columns
            )
```
What I am trying to do with these two lines:

```python
@patch('src.xx.infrastructure.utils.csvReader', side_effect=local_csvReader)
@patch('xx.infrastructure.utils.dataframeWriter', side_effect=local_dataframeWriter)
```

is to replace the definition of csvReader with local_csvReader and the definition of dataframeWriter with local_dataframeWriter.
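For context, this is the mechanism I believe these decorators rely on, illustrated with a small self-contained example that patches os.path.exists (nothing here comes from my actual project): patch() imports the module named in the string, swaps the given attribute for a MagicMock while the test runs, and side_effect makes every call to that mock delegate to my replacement function.

```python
from unittest import TestCase
from unittest.mock import patch
import os.path


def fake_exists(path: str) -> bool:
    # Stand-in implementation used instead of the real os.path.exists.
    return True


class PatchMechanismExample(TestCase):
    # patch() looks up the module 'os.path' and replaces its 'exists' attribute
    # with a MagicMock for the duration of the test; side_effect makes calls to
    # that mock delegate to fake_exists, while the mock still records the calls.
    @patch('os.path.exists', side_effect=fake_exists)
    def test_exists_is_replaced(self, mock_exists):
        self.assertTrue(os.path.exists('/definitely/not/a/real/path'))
        mock_exists.assert_called_once_with('/definitely/not/a/real/path')
```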
Unfortunately, the code keeps raising:

```
py4j.protocol.Py4JJavaError: An error occurred while calling o39.load.
: java.lang.RuntimeException: java.lang.ClassNotFoundException: Class org.apache.hadoop.fs.s3a.S3AFileSystem not found
```

The error points to the csvReader call in the main code (the first snippet), so my replacement technique clearly isn't working.

What am I doing wrong?
Replace the s3a:// URLs with file:// URLs and Spark will simply read from and write to the local filesystem.
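A minimal sketch of what that could look like, assuming you let the caller build the full base URI (the base_uri parameter is my addition, not part of your current helpers; in production it would be f's3a://{bucket}/{path}', and in tests something like 'file:///tmp/input_mock'):

```python
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import StructType


# Sketch: the helpers take a full base URI instead of hard-coding 's3a://'.
# With a file:// base URI no S3A classes are ever loaded during the test.
def csvReader(spark: SparkSession, base_uri: str, table: str, schema: StructType, sep: str) -> DataFrame:
    return (
        spark.read.format('csv')
        .option('header', 'true')
        .option('sep', sep)
        .schema(schema)
        .load(f'{base_uri}/{table}.csv')
    )


def dataframeWriter(df: DataFrame, base_uri: str, table: str, partition_key: str) -> None:
    df.write.partitionBy(partition_key).mode('overwrite').parquet(f'{base_uri}/{table}/')
```

The test can then point the job at local file:// directories instead of buckets, and the readers and writers themselves no longer need to be patched.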
Trying to mock S3 semantics is a very complex undertaking, and I strongly advise against it. If you really must go down that path, deploy MinIO locally or try Adobe's S3 mock: https://github.com/adobe/S3Mock. It supports all the S3 operations the S3A connector uses. At least I think so... we only test s3a code against real S3 endpoints: S3 Standard, S3 Express, Google GCS (!) and a few third-party stores. It is a bit slower, but much better at surfacing regressions and quirks.
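If you do go down the MinIO / S3Mock route, the test's Spark session also has to be pointed at the local endpoint. Roughly something like the following; the endpoint URL, port and credentials are placeholders that depend on how you start the container, and hadoop-aws plus the matching AWS SDK jars still have to be on the classpath (the ClassNotFoundException in the question means they currently are not):

```python
from pyspark.sql import SparkSession

# Rough sketch: point the S3A connector at a locally running MinIO / S3Mock
# endpoint. URL, port and credentials below are placeholders.
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('s3a-against-local-mock')
    .config('spark.hadoop.fs.s3a.endpoint', 'http://127.0.0.1:9000')
    .config('spark.hadoop.fs.s3a.access.key', 'minioadmin')
    .config('spark.hadoop.fs.s3a.secret.key', 'minioadmin')
    .config('spark.hadoop.fs.s3a.path.style.access', 'true')
    .config('spark.hadoop.fs.s3a.connection.ssl.enabled', 'false')
    .getOrCreate()
)

# With the bucket created in the local mock beforehand, s3a:// URIs now resolve
# against it instead of AWS.
df = spark.read.format('csv').option('header', 'true').load('s3a://test_bronze/chemins.csv')
```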