I am starting from a single file of roughly 21.9 GB. My compute resources are limited, so I decided to split the file by geographic region (climate classification). I read the data with xarray, but when I export the sub-files with .to_netcdf it takes a very long time and the output is far larger than the input, up to 300 GB in total. I use chunking so that I can process the file without running out of memory, but I suspect I am doing something wrong. I have attached my code as a .txt file.
Description of the .nc file:
Dimensions: (time: ~4100, y: 385, x: 541)
Coordinates:
I have tried logging memory usage and deleting unnecessary objects from memory as I go, but I suspect the problem lies in the chunking or in the way I export the files.
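For context, this is roughly how I inspect the dimensions and chunk layout after opening the file (a minimal sketch; the path matches the script below, and the sizes in the comments are the ones listed above):
import xarray as xr

# Open lazily; chunks="auto" lets dask choose a chunk size per variable
ds = xr.open_dataset("/home/gridsan/gmiller/climate/dataset_greece.nc", chunks="auto")

print(ds.sizes)                         # expected: time ~4100, y: 385, x: 541
print(ds.chunks)                        # dask chunk layout per dimension
print(f"{ds.nbytes / 1024**3:.1f} GB")  # uncompressed in-memory size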
import xarray as xr
import numpy as np
import psutil
import rasterio
import os
import gc
import dask
from scipy.spatial import cKDTree
def log_memory(stage=""):
    process = psutil.Process()
    memory_used = process.memory_info().rss / 1024 ** 3  # Convert to GB
    print(f"[{stage}] Memory usage: {memory_used:.2f} GB", flush=True)
# Paths to files
legend_path = '/home/gridsan/gmiller/climate/legend.txt'
tif_path = '/home/gridsan/gmiller/climate/Beck_KG_V1_present_0p083.tif'
file_path = '/home/gridsan/gmiller/climate/dataset_greece.nc'
# Read legend
legend = {}
with open(legend_path, 'r') as file:
    for line in file:
        if ':' in line and line.strip()[0].isdigit():
            key, rest = line.strip().split(':', 1)
            key = int(key)
            classification = rest.split('[')[0].strip()
            legend[key] = classification
# Read raster data
log_memory("Before reading raster")
with rasterio.open(tif_path) as src:
    raster_data = src.read(1)  # Read classification band
    raster_transform = src.transform
# Extract coordinates
rows, cols = np.indices(raster_data.shape)
lon, lat = rasterio.transform.xy(raster_transform, rows, cols, offset="center")
lon = np.array(lon).flatten()
lat = np.array(lat).flatten()
values = raster_data.flatten()
# Filter valid points
lon_min, lat_min, lon_max, lat_max = 18, 34, 32, 43
mask = (values != 0) & (lon_min <= lon) & (lon <= lon_max) & (lat_min <= lat) & (lat <= lat_max)
lon, lat, values = lon[mask], lat[mask], values[mask]
del raster_data, rows, cols, mask # Free memory
gc.collect()
descriptions = [legend.get(value, "Unknown") for value in values]
log_memory("After reading raster")
# Create KDTree
coords_tree = cKDTree(np.column_stack((lon, lat)))
del lon, lat
log_memory("After creating KDTree")
# Load dataset with chunking to avoid OOM issues
log_memory("Before opening dataset")
ds = xr.open_dataset(file_path, chunks="auto")
ds = ds.unify_chunks()
print(ds.chunks, flush=True)
log_memory("After opening dataset")
# Filter variables with a time dimension
log_memory("Before filtering variables")
time_vars = [var for var in ds.data_vars if 'time' in ds[var].dims]
ds = ds[time_vars]
log_memory("After filtering variables")
# Create land mask using 'ndvi'
log_memory("Before creating land mask")
reference_var = "ndvi"
date_to_use = '2020-06-01T10:00:00.000000000' # Specify the desired date explicitly
# Select the data for the specified date
land_mask = ds[reference_var].sel(time=date_to_use).notnull()
log_memory("After creating land mask")
# Apply land mask lazily
ds = ds.where(land_mask)
log_memory("After applying land mask")
# Generate valid coordinates
x_coords, y_coords = np.meshgrid(ds["x"].values, ds["y"].values)
# Flatten the grids and apply the land mask
land_mask_flat = land_mask.values.flatten()
valid_coords = np.column_stack((
    x_coords.flatten()[land_mask_flat],
    y_coords.flatten()[land_mask_flat]
))
del x_coords, y_coords
log_memory("After generating valid coordinates")
# Query KDTree
distances, indices = coords_tree.query(valid_coords)
del coords_tree, valid_coords
log_memory("After querying KDTree")
classification_values = values[indices]
del indices, values
classification_descriptions = [legend.get(int(val), "Unknown") for val in classification_values]
log_memory("After classification mapping")
# Assign classifications to dataset
classification_value_data = np.full(land_mask.shape, np.nan)
classification_description_data = np.full(land_mask.shape, np.nan, dtype=object)
classification_value_data[land_mask.values] = classification_values
classification_description_data[land_mask.values] = classification_descriptions
# Add to dataset
ds = ds.assign(
    classification_value=(("y", "x"), classification_value_data),
    classification_description=(("y", "x"), classification_description_data)
)
log_memory("After assigning classifications")
del classification_value_data, classification_description_data, classification_values, classification_descriptions
gc.collect()
output_dir = "classification_datasets"
os.makedirs(output_dir, exist_ok=True)
excluded_classifications = {6, 7, 9, 14, 15, 18, 19, 25, 26, 27, 29}
unique_classifications = np.unique(ds["classification_value"].values[~np.isnan(ds["classification_value"].values)])
remaining_classifications = [c for c in unique_classifications if c not in excluded_classifications]
# Generate dynamic encoding for all variables
encoding = {}
for var in ds.data_vars:
    var_dims = ds[var].dims    # Get dimensions of the variable
    var_shape = ds[var].shape  # Get the shape of the variable
    var_chunks = tuple(min(size, 50) for size in var_shape)  # Adjust chunk sizes
    encoding[var] = {
        "zlib": True,             # Enable compression
        "complevel": 4,           # Compression level (1-9, 4 is a good balance)
        "chunksizes": var_chunks  # Chunk sizes
    }
for classification in remaining_classifications:
print(f"Processing classification {classification}...", flush=True)
# Lazy mask application
land_mask = ds["classification_value"] == classification
classification_ds = ds.where(land_mask, drop=True)
classification_ds = classification_ds.chunk({"time": 10}) # Ensure chunking
# Manually create time chunks if chunks are missing
if classification_ds["time"].chunks is None:
time_values = classification_ds["time"].values
time_chunks = np.array_split(time_values, 10) # Split into 10 chunks
for time_chunk in time_chunks:
print(f"Processing time chunk {time_chunk[0]} to {time_chunk[-1]} for classification {classification}...")
chunk_ds = classification_ds.sel(time=slice(time_chunk[0], time_chunk[-1]))
file_name = f"classification_{int(classification)}_time_{time_chunk[0]}.nc"
file_path = os.path.join(output_dir, file_name)
with dask.config.set({"array.slicing.split_large_chunks": True}):
chunk_ds.to_netcdf(
file_path,
compute=True,
engine="netcdf4",
encoding=encoding
)
del chunk_ds
gc.collect()
log_memory(f"After saving time chunk {time_chunk[0]} for classification {classification}")
del classification_ds, land_mask
gc.collect()
log_memory(f"After processing classification {classification}")
print("Processing complete.", flush=True)
I assume the much larger output files and the long processing time come from the chunking, the compression settings, and an inefficient way of writing the data.
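My current guess at what "aligned" chunking and encoding should look like is something along these lines (a minimal sketch with illustrative chunk sizes and an illustrative output filename, not the values from my script). Is this the right direction?
import xarray as xr

# Illustrative target chunks: whole spatial slices, 100 time steps per chunk
target = {"time": 100, "y": 385, "x": 541}

ds = xr.open_dataset("/home/gridsan/gmiller/climate/dataset_greece.nc", chunks=target)

encoding = {}
for name, da in ds.data_vars.items():
    # On-disk (netCDF4/HDF5) chunks that match the dask chunks,
    # capped at each dimension's actual size
    chunksizes = tuple(min(da.sizes[d], target.get(d, da.sizes[d])) for d in da.dims)
    encoding[name] = {"zlib": True, "complevel": 4, "chunksizes": chunksizes}

ds.to_netcdf("aligned_example.nc", engine="netcdf4", encoding=encoding)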