python ThreadPoolExecutor 或 ProcessPoolExecutor 内存不足

问题描述 投票:0回答:1

我写了一个接口,用于查询数据,处理后以数据文件的形式返回给前端。其中查询数据使用了多线程,处理数据使用了多进程。

我的服务部署在 K8S 上,用户一次请求的数据量有时很大(几万到几百万条)。多次请求后,容器内存经常超过 10G(这是我为容器设置的内存上限),于是容器被自动重启,有时还会出现下面的错误信息:

`concurrent.futures.process.BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.`

我尝试过 `gc.collect()`,也尝试过分块将数据写入文件,代码如下。我应该如何优化来解决这个问题?

@download_bp.route('/dataDownload', methods=['POST'])
@siwa.doc(body=GetChartDataBody, tags=['数据下载'], summary='数据下载')
def download_file(body: GetChartDataBody):
    """
    Export the requested chart data as a CSV file and send it to the client.

    The CSV path is taken from the 'data' entry of get_chart_data2file's result.
    A one-shot after-request hook deletes the file once the response object has been built.

    :param body: validated request body (terminal, variables, time range)
    :return: the send_file response for the CSV
    """
    file_path = get_chart_data2file(body)['data']

    @after_this_request
    def remove_file(response):
        # Best-effort cleanup: a failure is logged but must not break the response.
        # NOTE(review): deleting here relies on send_file having already opened the
        # file (true on POSIX without X-Sendfile) -- confirm for the deployment target.
        try:
            os.remove(file_path)
        except Exception as error:
            # Embed the error in the message itself; passing it as an extra positional
            # argument without a placeholder lost the error details.
            logger.error(f"Error removing or closing downloaded file handle: {error}")
        return response

    return send_file(file_path, mimetype='application/csv')

import os
import gc
import time
from flask import request
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from functools import partial
from ..utils import get_login_user
from ..log import logger
from ..dtos.request_body import GetChartDataBody
from ..services.get_vars import conn_influxdb
import pandas as pd
from ..dtos.resp_result import RespResult
from ..exception import BizException
from ..settings import settings
from datetime import datetime, timedelta
from ..services import pd_data_nan2none
from . import user_file_path
import pickle
import uuid
import shutil


def get_chart_data2file(body: GetChartDataBody):
  """
  Query the requested variables (multi-threaded), process them (multi-process) and write a CSV.

  The CSV is written under user_file_path with a name that is unique per request, so
  concurrent downloads can neither overwrite nor delete each other's file. The caller
  is responsible for removing the file after sending it.

  :param body: request body providing terminal_id, vars_list, start_time and end_time
  :return: RespResult.success(<path of the written CSV file>)
  """
  gc.collect()
  terminal_id = body.terminal_id
  vars_list = body.vars_list
  start_time_str = body.start_time
  end_time_str = body.end_time
  table_name = terminal_id + '_xcp'
  record_time_start = time.time()
  user_name = get_login_user(request)
  # One id per request: names both the scratch directory and the output CSV.
  request_id = f"{user_name}_{uuid.uuid4()}"
  user_directory = os.path.join(user_file_path, request_id)
  os.makedirs(user_directory, exist_ok=True)
  result_df = gen_result_df(user_directory, table_name, vars_list, start_time_str, end_time_str, True)
  # Down-cast every value column to float32; 'timestamps' is not listed, so it keeps its dtype.
  dtype_dict = {col: 'float32' for col in result_df.columns if col != 'timestamps'}
  result_df = result_df.astype(dtype_dict)
  # Previously every request wrote to the same settings.FILE_NAME, so simultaneous
  # downloads clobbered each other and one request's cleanup removed another's file.
  file_path = os.path.join(user_file_path, f"{request_id}_{settings.FILE_NAME}")
  result_df.to_csv(file_path, index=False, chunksize=10000)
  record_time_end = time.time()
  spend_time = record_time_end - record_time_start
  logger.info(f'{user_name} download all finished: terminal_id:{terminal_id}, query_vars:{vars_list}, query_time_range:{start_time_str}至{end_time_str},spend_time: {spend_time}s')
  del result_df
  gc.collect()
  return RespResult.success(file_path)


def gen_result_df(user_directory, table_name, vars_list, start_time_str, end_time_str, is_download=False):
  """
  Query the time range in half-day windows and merge the processed windows into one DataFrame.

  Windows from gen_time_group are queried on a thread pool. query_database pickles every
  non-empty window into user_directory. handle_data then turns each pickle into a
  DataFrame on a process pool, and the frames are concatenated and sorted by 'timestamps'.
  user_directory is removed on every exit path, including errors.

  :param user_directory: scratch directory for the per-window pickle files (deleted on exit)
  :param table_name: measurement to query
  :param vars_list: field names to select
  :param start_time_str: range start, '%Y-%m-%d %H:%M:%S'
  :param end_time_str: range end, '%Y-%m-%d %H:%M:%S'
  :param is_download: forwarded to handle_data; True selects the download (non-downsampled) path
  :return: merged DataFrame sorted ascending by 'timestamps'
  :raises BizException: code 201 when no window returned any data
  """
  logger.info(f'start query and handle data:{time.time()} ')
  try:
    time_groups = gen_time_group(start_time_str, end_time_str)
    query = partial(query_database, user_directory)
    with ThreadPoolExecutor(max_workers=3) as executor:
      futures = [executor.submit(query, vars_list, table_name, start_time, end_time)
                 for start_time, end_time in time_groups]
      results = [future.result() for future in futures]

    if not any(results):
      raise BizException(201, '该时间范围内无数据,请重新选择时间')

    # query_database names files data_<YYYY-MM-DD>_<HH>.pkl, so a lexical sort is chronological.
    file_paths = sorted(os.path.join(user_directory, name) for name in os.listdir(user_directory))

    # A single map over all files: batching would not bound memory because every
    # chunk is kept until the concat below anyway.
    with ProcessPoolExecutor(max_workers=2) as executor:
      df_chunks = list(executor.map(partial(handle_data, is_download=is_download), file_paths))
  finally:
    # Remove the scratch pickles on every path (no data, query failure, worker failure).
    shutil.rmtree(user_directory, ignore_errors=True)

  # handle_data returns None for a window it could not process, and pd.concat drops
  # None entries silently, so at least make the data loss visible.
  failed = sum(chunk is None for chunk in df_chunks)
  if failed:
    logger.error(f'gen_result_df: {failed} of {len(df_chunks)} data chunks failed to process')

  result_df = pd.concat(df_chunks, ignore_index=True)
  # Release the per-chunk frames before sorting so they are not alive at the same time
  # as the concatenated frame and the sort's working copy (lower peak memory).
  del df_chunks
  result_df.sort_values(by='timestamps', ascending=True, inplace=True)
  return result_df


def gen_time_group(start_time_str, end_time_str):
  """
  Split a time range into consecutive 12-hour windows for parallel querying.

  The final window is truncated at the range end. An empty or inverted range yields [].

  :param start_time_str: range start, '%Y-%m-%d %H:%M:%S'
  :param end_time_str: range end, '%Y-%m-%d %H:%M:%S'
  :return: list of [window_start, window_end] string pairs in the same format
  """
  fmt = '%Y-%m-%d %H:%M:%S'
  window = timedelta(hours=12)
  cursor = datetime.strptime(start_time_str, fmt)
  finish = datetime.strptime(end_time_str, fmt)
  windows = []
  while cursor < finish:
    upper = min(cursor + window, finish)
    windows.append([cursor.strftime(fmt), upper.strftime(fmt)])
    cursor += window
  return windows


def query_database(user_directory, vars_list_group, table_name, start_time, end_time):
  """
  Query one time window from InfluxDB and pickle a non-empty result into user_directory.

  The output file is named data_<date>_<hour>.pkl after the window start.

  :param user_directory: directory the pickle file is written to
  :param vars_list_group: field names to select (taken from the request body)
  :param table_name: measurement name (derived from the request body)
  :param start_time: window start (inclusive), '%Y-%m-%d %H:%M:%S'
  :param end_time: window end (exclusive), '%Y-%m-%d %H:%M:%S'
  :return: 1 if a non-empty result was written, otherwise None
  """
  def quote_ident(name):
    # InfluxQL cannot bind identifiers as parameters, and these values come from the
    # request, so double-quote them and escape backslashes, quotes and newlines.
    escaped = str(name).replace('\\', '\\\\').replace('"', '\\"').replace('\n', '\\n')
    return f'"{escaped}"'

  fields = ','.join(quote_ident(var) for var in vars_list_group)
  query = f"SELECT {fields} from {quote_ident(table_name)} where time >= $start_time and time < $end_time"
  params = {
    "start_time": start_time,
    "end_time": end_time
  }
  client = conn_influxdb()
  try:
    result = client.query(query, bind_params=params)
  finally:
    # Close the connection even when the query raises.
    client.close()
  logger.debug(f'query params: {params}')
  logger.info(f'finish query:{time.time()}')
  if not result:
    return None

  date_part, time_part = start_time.split(' ')[0], start_time.split(' ')[1]
  file_name = os.path.join(user_directory, f"data_{date_part}_{time_part[:2]}.pkl")
  with open(file_name, 'wb') as f:
    pickle.dump(result, f)
  del result
  return 1


def handle_data(file_path, is_download):
  """
  Load one pickled query result and turn it into a numeric DataFrame.

  Intended to run in a ProcessPoolExecutor worker, so the returned frame is pickled
  back to the parent process.

  :param file_path: pickle file written by query_database
  :param is_download: True -> 'timestamps' in seconds, rows sharing a second collapsed
      (last kept) and all-NaN columns dropped; False -> 'timestamps' in milliseconds
      and NaNs converted by pd_data_nan2none
  :return: the processed DataFrame, or None if processing failed (the error is logged)
  """
  try:
    # Only load pickles this service wrote itself: unpickling can execute arbitrary code.
    with open(file_path, "rb") as f:
      result = pickle.load(f)
    data_df = pd.DataFrame(list(result.get_points()))
    # Release the raw result as soon as the frame exists to lower the worker's peak memory.
    del result
    # NOTE(review): shifts the parsed 'time' values by -8 hours before converting to
    # epoch nanoseconds -- presumably a UTC+8 adjustment; confirm.
    time_adjust = (pd.to_datetime(data_df['time']) - pd.Timedelta(hours=8)).astype('int64')
    if is_download:
      data_df['timestamps'] = time_adjust // 10 ** 9
      data_df.drop_duplicates(subset='timestamps', keep='last', inplace=True)
      data_df.dropna(axis=1, how='all', inplace=True)
    else:
      data_df['timestamps'] = time_adjust // 10 ** 6
    del time_adjust

    data_df.drop(['time'], axis=1, inplace=True)
    # Coerce every remaining column to numbers (unparseable -> NaN) and keep 3 decimals.
    data_df = data_df.apply(pd.to_numeric, errors='coerce').round(3)
    if not is_download:
      data_df = pd_data_nan2none(data_df)
    logger.info(f'finish handle:{time.time()}')
    return data_df
  except Exception as e:
    # Best effort: the caller's pd.concat silently skips None chunks, so name the
    # file to make the missing time window traceable.
    logger.error(f'handle_data function error: {file_path}: {e}')
    return None

memory status

python python-3.x out-of-memory threadpoolexecutor
1个回答
0
投票

在写入文件时分块(`to_csv(..., chunksize=10000)`)并没有意义,因为在此之前您已经在内存中得到了一个完整的、非常大的 DataFrame `result_df`。

一个思路是:把较小的、可能未排序的块分别写入单独的 CSV 文件,然后用标准 UNIX `sort` 工具把这些块的行合并到最终输出文件中。这样,完整的 DataFrame 就永远不需要整体放在内存里。`sort` 处理这类任务往往非常高效(当然取决于具体实现),如果使用 ASCII 语言环境(例如 `LC_ALL=C`)则更是如此。注意各块文件应不带表头写入(`header=False`),表头单独写在最终文件的开头,否则表头行也会被一起排序。

假设您的时间戳可按字典顺序排序(例如 ISO8601 时间戳),总体思路是:

$ cat > a.csv
2024-01-01
2024-04-33
2022-02-55
$ cat > b.csv
2023-01-33
2022-01-11
$ sort a.csv b.csv
2022-01-11
2022-02-55
2023-01-33
2024-01-01
2024-04-33
$
© www.soinside.com 2019 - 2024. All rights reserved.