我正在尝试创建一种(类似)蜂群图(swarm plot)的图表:它应当清楚地显示分布的形状,同时通过允许代表数据点的圆点相互重叠,从而能够快速绘制数以万计的数据点。比如这样:
我的想法本质上是创建一个点图:把每个分布划分为若干分位数区间,并对数据点的水平位置施加抖动,抖动的幅度与所在分位数区间中的点数成正比。当各个分布的数据量相近时,这种方法效果很好;但我需要某种方法来缩放抖动,使得当其中一个分布只有几个数据点时,代表它们的点能够排列在一条(几乎)垂直的线上,而不是像下面这样:
这是我的绘图创建代码:
import matplotlib.pyplot as plt
import numpy as np
def fancy_distribution_plot(distributions: list, tick_labels: list, max_plot_width: int = 1, alpha=0.7,
                            number_of_segments=12,
                            separation_between_plots=0.1,
                            separation_between_subplots=0.1,
                            vertical_limits=None,
                            grid=False,
                            remove_outlier_above_segment=None,
                            remove_outlier_below_segment=None,
                            y_label=None,
                            title=None):
    """Draw one jittered strip of points per distribution and show the figure.

    Each distribution's value range is cut into ``number_of_segments``
    equal-width segments.  A point is shifted horizontally by a uniform random
    amount whose maximum is proportional to how many points share its segment,
    so densely populated segments look wide and sparse ones look narrow.

    The width is normalised separately for every distribution (by that
    distribution's own fullest segment), so the fullest segment of each strip
    always spans ``max_plot_width`` regardless of its sample size.

    Args:
        distributions: List of 1-D NumPy arrays, one per strip.
        tick_labels: X-axis label for each strip.
        max_plot_width: Full width of the widest segment of a strip.
        alpha: Marker opacity forwarded to ``Axes.scatter``.
        number_of_segments: Number of equal-width value segments per strip.
        separation_between_plots: Horizontal gap between neighbouring strips.
        separation_between_subplots: Not used by this function.
        vertical_limits: Optional ``(bottom, top)`` pair for the y-axis.
        grid: When falsy the axes grid is switched off explicitly.
        remove_outlier_above_segment: Optional per-strip sequence; points
            whose 0-based segment index exceeds the entry are dropped.
        remove_outlier_below_segment: Optional per-strip sequence; points
            whose 0-based segment index is below ``entry - 1`` are dropped.
        y_label: Optional y-axis label.
        title: Optional axes title.

    Side effects:
        Prints the list of tick positions, consumes numbers from NumPy's
        global random state, and calls ``plt.show()``.
    """
    _fig, ax = plt.subplots()
    plot_count = len(distributions)
    pitch = max_plot_width + separation_between_plots
    ax.set_xlim(left=0, right=plot_count * pitch + separation_between_plots)
    ticks = [separation_between_plots + max_plot_width / 2 + pitch * position
             for position in range(0, plot_count)]
    print(ticks)

    for i, values in enumerate(distributions):
        # Interior edges only: the outermost two edges are min and max.
        edges = np.linspace(np.min(values), np.max(values), number_of_segments + 1)[1:-1]
        # A point's segment is found by counting the edges at or above it.
        edges_at_or_above = (edges[:, None] >= values[None, :]).sum(axis=0)
        bins = number_of_segments - 1 - edges_at_or_above

        if remove_outlier_above_segment:
            keep = bins <= remove_outlier_above_segment[i]
            values = values[keep]
            bins = bins[keep]
        if remove_outlier_below_segment:
            keep = bins >= remove_outlier_below_segment[i] - 1
            values = values[keep]
            bins = bins[keep]

        occupied, counts = np.unique(bins, return_counts=True)
        count_of = dict(zip(occupied.tolist(), counts))
        per_segment = [count_of.get(k, 0) for k in range(number_of_segments)]

        # Half-width of each segment, relative to this strip's fullest one.
        half_widths = (max_plot_width / 2) * (per_segment / np.max(counts))
        unit_jitter = np.random.uniform(-1, 1, len(values))
        jitter = half_widths[bins] * unit_jitter
        ax.scatter(jitter + ticks[i], values, alpha=alpha)

    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    if vertical_limits:
        bottom, top = vertical_limits[0], vertical_limits[1]
        ax.set_ylim(bottom=bottom, top=top)
    if not grid:
        ax.grid(False)
    if y_label:
        ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    plt.show()
以及重新创建上面第二个图表的代码:
# Reproducible toy data: three normal samples of very different sizes
# (4, 10 and 1000 points) so the effect of sample size on the jitter shows.
np.random.seed(0)
distro1 = np.random.normal(loc=0, scale=2, size=4)
distr2 = np.random.normal(loc=1, scale=1, size=10)
distr3 = np.random.normal(loc=2, scale=3, size=1000)
distributions = [distro1, distr2, distr3]
fancy_distribution_plot(
    distributions,
    tick_labels=['distro1', 'distro2', 'distro3'],
    number_of_segments=100,
    grid=False,
)
扩展一下我的评论:您可以把方差(以及相应的抖动)除以所有分布中最大的 `count` 值来进行缩放。
一个可能的实现(从你的函数开始)是:
import matplotlib.pyplot as plt
import numpy as np
def fancy_distribution_plot(distributions: list, tick_labels: list, max_plot_width: int = 1, alpha=0.7,
                            number_of_segments=12,
                            separation_between_plots=0.1,
                            separation_between_subplots=0.1,
                            vertical_limits=None,
                            grid=False,
                            remove_outlier_above_segment=None,
                            remove_outlier_below_segment=None,
                            y_label=None,
                            title=None):
    """Draw jittered strips of points that share one global width scale.

    Each distribution's value range is cut into ``number_of_segments``
    equal-width segments.  A point is shifted horizontally by a uniform random
    amount whose maximum is proportional to the number of points in its
    segment.  The proportionality constant is common to all strips: it is
    chosen so that the single fullest segment across *all* distributions spans
    ``max_plot_width``.  A strip with only a few points therefore collapses to
    an (almost) vertical line next to a strip with thousands of points.

    Args:
        distributions: Sequence of 1-D array-likes, one per strip.  Empty
            entries are allowed and simply produce an empty strip.
        tick_labels: X-axis label for each strip.
        max_plot_width: Full width of the fullest segment over all strips.
        alpha: Marker opacity forwarded to ``Axes.scatter``.
        number_of_segments: Number of equal-width value segments per strip.
        separation_between_plots: Horizontal gap between neighbouring strips.
        separation_between_subplots: Currently unused; kept so existing
            callers that pass it keep working.
        vertical_limits: Optional ``(bottom, top)`` pair for the y-axis.
        grid: When falsy the axes grid is switched off explicitly.
        remove_outlier_above_segment: Optional per-strip sequence; points
            whose 0-based segment index exceeds the entry are dropped.
        remove_outlier_below_segment: Optional per-strip sequence; points
            whose 0-based segment index is below ``entry - 1`` are dropped.
        y_label: Optional y-axis label.
        title: Optional axes title.

    Side effects:
        Consumes numbers from NumPy's global random state and calls
        ``plt.show()``.
    """
    fig, ax = plt.subplots()
    number_of_plots = len(distributions)
    ax.set_xlim(left=0, right=number_of_plots * (max_plot_width + separation_between_plots) + separation_between_plots)
    ticks = [separation_between_plots + max_plot_width / 2 + (max_plot_width + separation_between_plots) * i
             for i in range(0, number_of_plots)]

    # Pass 1: bin every distribution.  The global maximum segment count is
    # only known once all of them have been binned, hence the two passes.
    kept_values = []
    kept_segment_indices = []
    segment_counts = []
    for i, raw in enumerate(distributions):
        values = np.asarray(raw)
        if values.size:
            # Interior edges only: the outermost two edges are min and max.
            edges = np.linspace(np.min(values), np.max(values), number_of_segments + 1)[1:-1]
            # Number of edges strictly below a value == its 0-based segment.
            # For finite data this equals the original
            # ``n - 1 - count(edges >= value)`` without building an
            # (edges x points) comparison matrix.
            segment_indices = np.searchsorted(edges, values, side="left")
        else:
            segment_indices = np.zeros(0, dtype=np.intp)

        if remove_outlier_above_segment:
            keep = segment_indices <= remove_outlier_above_segment[i]
            values = values[keep]
            segment_indices = segment_indices[keep]
        if remove_outlier_below_segment:
            keep = segment_indices >= remove_outlier_below_segment[i] - 1
            values = values[keep]
            segment_indices = segment_indices[keep]

        # Keep the *filtered* values: they must stay aligned with the
        # filtered segment indices when the jitter is applied below.
        kept_values.append(values)
        kept_segment_indices.append(segment_indices)
        segment_counts.append(np.bincount(segment_indices, minlength=number_of_segments))

    max_counts = max((int(counts.max()) for counts in segment_counts), default=0)

    # Pass 2: scale every strip by the same global maximum and draw it.
    for tick, values, segment_indices, counts in zip(ticks, kept_values, kept_segment_indices, segment_counts):
        if max_counts:
            variances = (max_plot_width / 2) * (counts / max_counts)
        else:
            # Nothing left to plot anywhere; avoid dividing by zero.
            variances = np.zeros(len(counts))
        jitter_unadjusted = np.random.uniform(-1, 1, len(values))
        jitter = variances[segment_indices] * jitter_unadjusted
        # Scatter is called even for an empty strip so that the colour cycle
        # stays aligned with the strip position.
        ax.scatter(jitter + tick, values, alpha=alpha)

    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    if vertical_limits:
        ax.set_ylim(bottom=vertical_limits[0], top=vertical_limits[1])
    if not grid:
        ax.grid(False)
    if y_label:
        ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    plt.show()
使用你的玩具示例中的数据,可以得到如下结果:
这段代码比较凌乱,重复的 for 循环既不优雅也不高效;但我希望至少结果是您想要的。