我想知道是否可以使用
mpl-scatter-density
或 datashader
创建按密度着色的 1D 散点图,如 here 所示的 2D 情况。
可以用
gaussian_kde
来完成,但是当我要表示的点数在10k以上时,速度相当慢。
另外,是否有一种方法可以在不将轴定义为
mpl-scatter-density
的情况下执行 fig.add_subplot(1, 1, 1, projection='scatter_density')
方法,而只需使用 plt.subplots
创建它们?
我尝试使用
ScatterDensityArtist
中的 mpl_scatter_density
来实现此目的,但没有成功。
这里是使用
gaussian_kde
按密度着色的一维散点图的一些示例代码。
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
# Generate fake data
data_x = np.broadcast_to(np.array([[1], [2], [3]]), (2, 3, 1000))
data_y = data_x*np.random.normal(size=(2,3,1000))
# Create subplots
nrows = 1
ncols = 2
size = 5
fig, ax_array = plt.subplots(
nrows,
ncols,
figsize=(16/9*ncols*size,nrows*size),
squeeze=False
)
for i,ax_row in enumerate(ax_array):
for j,axes in enumerate(ax_row):
index = nrows*i+j
x = data_x[index,:,:]
y = data_y[index,:,:]
for x_values,y_values in zip(x,y):
z_values = gaussian_kde(y_values)(y_values)
idx = z_values.argsort()
x_values, y_values, z_values = x_values[idx], y_values[idx], z_values[idx]
axes.scatter(
x_values,y_values,
c=z_values, s=10,
cmap=plt.cm.get_cmap('Reds')
)
plt.show()
您可以在更少的点(例如 100 个)上对其进行评估,然后使用这些点进行插值以获得所需的
,而不是尝试在所有
y_values
处评估 KDE(当它由许多点组成时,这会很慢) z_values
。我发现这速度快了 50 倍(当 y_values
有 10000 个样本并且插值器以 100 个点评估 KDE 时)。例如,
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
# get interp1d
from scipy.interpolate import interp1d
# create interpolator
def interpolated(y_values, npoints=100):
# set y values at which to interpolate
yvals = np.linspace(y_values.min(), y_values.max(), npoints, endpoint=True)
# generate the interpolation function
ifunc = interp1d(yvals, gaussian_kde(y_values)(yvals))
return ifunc(y_values)
# Generate fake data
data_x = np.broadcast_to(np.array([[1], [2], [3]]), (2, 3, 10000))
data_y = data_x*np.random.normal(size=(2,3,10000))
# Create subplots
nrows = 1
ncols = 2
size = 5
fig, ax_array = plt.subplots(
nrows,
ncols,
figsize=(16/9*ncols*size,nrows*size),
squeeze=False
)
cmap = plt.cm.get_cmap('Reds')
for i,ax_row in enumerate(ax_array):
for j,axes in enumerate(ax_row):
index = nrows*i+j
x = data_x[index,:,:]
y = data_y[index,:,:]
for x_values,y_values in zip(x,y):
# use interpolator to get z_values
z_values = interpolated(y_values)
idx = z_values.argsort()
x_values, y_values, z_values = x_values[idx], y_values[idx], z_values[idx]
axes.scatter(
x_values,y_values,
c=z_values, s=10,
cmap=cmap
)
plt.show()
将 subplot_kw 添加到您的 plt.subplots 参数中:
fig, ax_array = plt.subplots(
nrows,
ncols,
figsize=(16/9*ncols*size,nrows*size),
squeeze=False,
subplot_kw={'projection': 'scatter_density'}
)