如何在 Matplotlib 中创建一个看起来像群图但有重叠点的图?

问题描述 投票:0回答:1

我正在尝试创建一个(某种)群图 - 它应该清楚地显示分布的形状,但允许通过重叠数据点的点表示来快速绘制数以万计的数据点。比如这样:

enter image description here

我的想法是本质上创建一个点图,但将每个分布划分为若干分位区间,并对数据点的水平位置施加抖动,抖动的幅度与该点所在区间内的点数成正比。当各个分布的数据点数量相同时,这种方法效果很好,但我需要某种方法来缩放抖动,使得当其中一个分布只有少数几个数据点时,代表它们的点能够排列在一条(几乎)垂直的线上,而不是像下面这样:

enter image description here

这是我的绘图创建代码:

import matplotlib.pyplot as plt
import numpy as np


def fancy_distribution_plot(distributions: list, tick_labels: list, max_plot_width: int = 1, alpha=0.7,
                            number_of_segments=12,
                            separation_between_plots=0.1,
                            separation_between_subplots=0.1,
                            vertical_limits=None,
                            grid=False,
                            remove_outlier_above_segment=None,
                            remove_outlier_below_segment=None,
                            y_label=None,
                            title=None):
    """Draw a swarm-like scatter plot of several 1-D samples side by side.

    Each sample's value range is cut into ``number_of_segments`` equal-width
    vertical segments. Every point receives a uniform random horizontal offset
    whose maximum magnitude is proportional to the number of points in its
    segment, normalised by the busiest segment of the SAME sample, so every
    sample spans the full ``max_plot_width`` regardless of its size.

    Args:
        distributions: Sequence of 1-D NumPy arrays, one per plotted column.
        tick_labels: X tick label for each column.
        max_plot_width: Horizontal width reserved for one column.
        alpha: Marker opacity passed to ``Axes.scatter``.
        number_of_segments: Number of equal-width vertical segments per sample.
        separation_between_plots: Horizontal gap between columns and at the
            left/right edges.
        separation_between_subplots: Not used by this function.
        vertical_limits: Optional ``(bottom, top)`` pair for the y axis.
        grid: When falsy, the grid is explicitly switched off.
        remove_outlier_above_segment: Optional per-column segment index ``a``;
            points with segment index greater than ``a`` are dropped.
        remove_outlier_below_segment: Optional per-column value ``b``; points
            with segment index less than ``b - 1`` are dropped.
        y_label: Optional y axis label.
        title: Optional axes title.

    The tick positions are printed, and the figure is displayed with
    ``plt.show()``. Nothing is returned.
    """
    fig, ax = plt.subplots()

    number_of_plots = len(distributions)
    # Horizontal space taken by one column plus the gap that follows it.
    slot_width = max_plot_width + separation_between_plots
    ax.set_xlim(left=0, right=number_of_plots * slot_width + separation_between_plots)

    # Centre of every column along the x axis.
    ticks = [separation_between_plots + max_plot_width / 2 + slot_width * slot
             for slot in range(0, number_of_plots)]
    print(ticks)

    for plot_index, points in enumerate(distributions):
        # Interior boundaries of the equal-width segments between min and max.
        inner_edges = np.linspace(np.min(points), np.max(points), number_of_segments + 1)[1:-1]
        # A point's segment index is the number of interior edges strictly below it.
        edges_at_or_above = (inner_edges[:, None] >= points[None, :]).sum(axis=0)
        bin_of_point = number_of_segments - 1 - edges_at_or_above

        if remove_outlier_above_segment:
            highest_kept = remove_outlier_above_segment[plot_index]
            keep = bin_of_point <= highest_kept
            points = points[keep]
            bin_of_point = bin_of_point[keep]

        if remove_outlier_below_segment:
            lowest_kept = remove_outlier_below_segment[plot_index] - 1
            keep = bin_of_point >= lowest_kept
            points = points[keep]
            bin_of_point = bin_of_point[keep]

        # Points per segment, with zeros for the segments that hold no point.
        occupied_bins, occupancy = np.unique(bin_of_point, return_counts=True)
        points_per_bin = np.zeros(number_of_segments, dtype=occupancy.dtype)
        points_per_bin[occupied_bins] = occupancy

        # Maximum horizontal offset per segment, relative to this sample's busiest segment.
        half_widths = (max_plot_width / 2) * (points_per_bin / np.max(occupancy))
        unit_jitter = np.random.uniform(-1, 1, len(points))
        x_offsets = half_widths[bin_of_point] * unit_jitter
        ax.scatter(x_offsets + ticks[plot_index], points, alpha=alpha)

    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    if vertical_limits:
        bottom, top = vertical_limits[0], vertical_limits[1]
        ax.set_ylim(bottom=bottom, top=top)
    if not grid:
        ax.grid(False)
    if y_label:
        ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    plt.show()

以及重新创建上面第二个图表的代码:

# Build three normally distributed samples of very different sizes.
np.random.seed(0)  # fixed seed so the example figure is reproducible
sample_specs = [(0, 2, 4), (1, 1, 10), (2, 3, 1000)]  # (mean, standard deviation, size)
distributions = [np.random.normal(mean, spread, size) for mean, spread, size in sample_specs]
distro1, distr2, distr3 = distributions

fancy_distribution_plot(
    distributions,
    tick_labels=['distro1', 'distro2', 'distro3'],
    number_of_segments=100,
    grid=False,
)

python matplotlib visualization
1个回答
0
投票

扩展我的评论:您可以将方差(进而将抖动)除以所有分布中最大的 `count` 来进行缩放。

一个可能的实现(从你的函数开始)是:

import matplotlib.pyplot as plt
import numpy as np


def _bin_distribution(values, number_of_segments, highest_segment=None, lowest_segment=None):
    """Assign points to equal-width vertical segments and count points per segment.

    Args:
        values: 1-D array-like of sample values.
        number_of_segments: Number of equal-width segments between the sample's
            minimum and maximum.
        highest_segment: If not None, points in a higher segment are dropped.
        lowest_segment: If not None, points in a lower segment are dropped.

    Returns:
        ``(kept_values, segment_indices, counts)``: the surviving points, the
        segment index of each surviving point, and an integer array of length
        ``number_of_segments`` holding the number of surviving points per segment.
    """
    values = np.asarray(values)
    if values.size == 0:
        # Nothing to bin; np.min / np.max would raise on an empty sample.
        return values, np.zeros(0, dtype=np.intp), np.zeros(number_of_segments, dtype=np.intp)

    # Interior boundaries of the segments (the outer two edges are min and max).
    inner_edges = np.linspace(np.min(values), np.max(values), number_of_segments + 1)[1:-1]
    # Segment index == number of interior edges strictly below the point.
    segment_indices = np.searchsorted(inner_edges, values, side="left")

    keep = np.ones(values.shape, dtype=bool)
    if highest_segment is not None:
        keep &= segment_indices <= highest_segment
    if lowest_segment is not None:
        keep &= segment_indices >= lowest_segment

    segment_indices = segment_indices[keep]
    counts = np.bincount(segment_indices, minlength=number_of_segments)
    return values[keep], segment_indices, counts


def fancy_distribution_plot(distributions: list, tick_labels: list, max_plot_width: float = 1, alpha=0.7,
                            number_of_segments=12,
                            separation_between_plots=0.1,
                            separation_between_subplots=0.1,
                            vertical_limits=None,
                            grid=False,
                            remove_outlier_above_segment=None,
                            remove_outlier_below_segment=None,
                            y_label=None,
                            title=None):
    """Draw a swarm-like scatter plot of several 1-D samples side by side.

    Each sample's value range is cut into ``number_of_segments`` equal-width
    vertical segments. Every point receives a uniform random horizontal offset
    whose maximum magnitude is proportional to the number of points in its
    segment, normalised by the busiest segment across ALL samples. A sample
    with only a few points therefore collapses to a nearly vertical line
    instead of being stretched over the full column width.

    Args:
        distributions: Sequence of 1-D array-likes, one per plotted column.
        tick_labels: X tick label for each column.
        max_plot_width: Horizontal width reserved for one column.
        alpha: Marker opacity passed to ``Axes.scatter``.
        number_of_segments: Number of equal-width vertical segments per sample.
        separation_between_plots: Horizontal gap between columns and at the
            left/right edges.
        separation_between_subplots: Accepted for backward compatibility; not used.
        vertical_limits: Optional ``(bottom, top)`` pair for the y axis.
        grid: Whether the axes grid is shown.
        remove_outlier_above_segment: Optional per-column segment index ``a``;
            points with segment index greater than ``a`` are dropped.
        remove_outlier_below_segment: Optional per-column value ``b``; points
            with segment index less than ``b - 1`` are dropped.
        y_label: Optional y axis label.
        title: Optional axes title.

    The figure is displayed with ``plt.show()``. Nothing is returned.
    """
    fig, ax = plt.subplots()

    number_of_plots = len(distributions)
    # Horizontal space taken by one column plus the gap that follows it.
    slot_width = max_plot_width + separation_between_plots
    ax.set_xlim(left=0, right=number_of_plots * slot_width + separation_between_plots)

    # Centre of every column along the x axis.
    ticks = [separation_between_plots + max_plot_width / 2 + slot_width * i
             for i in range(number_of_plots)]

    # Pass 1: bin every sample. The filtered points are stored so that the
    # plotting pass uses exactly the points the segment indices refer to.
    binned = []
    for i, distribution in enumerate(distributions):
        highest = remove_outlier_above_segment[i] if remove_outlier_above_segment else None
        # The lower bound is shifted by one: points are kept when index >= b - 1.
        lowest = remove_outlier_below_segment[i] - 1 if remove_outlier_below_segment else None
        binned.append(_bin_distribution(distribution, number_of_segments, highest, lowest))

    # Busiest segment over all samples; this is the common jitter scale.
    max_count = max((counts.max(initial=0) for _, _, counts in binned), default=0)

    # Pass 2: jitter and draw each sample with the shared scale.
    for tick, (kept, segment_indices, counts) in zip(ticks, binned):
        if max_count > 0:
            half_widths = (max_plot_width / 2) * (counts / max_count)
        else:
            # Every sample is empty; avoid dividing by zero.
            half_widths = np.zeros(len(counts))
        jitter = half_widths[segment_indices] * np.random.uniform(-1, 1, len(kept))
        ax.scatter(jitter + tick, kept, alpha=alpha)

    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    if vertical_limits:
        ax.set_ylim(bottom=vertical_limits[0], top=vertical_limits[1])
    # Honour the flag in both directions (previously grid=True was a no-op).
    ax.grid(bool(grid))
    if y_label:
        ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    plt.show()

使用你的示例数据,得到的结果如下:

Swarmplots

代码非常混乱,重复 for 循环既不优雅也不高效:我希望至少结果是您想要的。

© www.soinside.com 2019 - 2024. All rights reserved.