How do I create a normal distribution in PyTorch?

Problem description · votes: 0 · answers: 7

I want to create a random normal distribution with a given mean and standard deviation.

python pytorch normal-distribution
7 Answers

26 votes

You can easily use the torch.Tensor.normal_() method.

Let's create a matrix Z (a 1-D tensor) of dimension 1 × 5, filled with random elements sampled from a normal distribution parameterized by mean = 4 and std = 0.5:

torch.empty(5).normal_(mean=4,std=0.5)

Result:

tensor([4.1450, 4.0104, 4.0228, 4.4689, 3.7810])
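If you would rather not fill a tensor in place, a minimal sketch of the same draw using torch.normal with a size argument:

import torch

# Same distribution (mean=4, std=0.5), produced directly rather than in place.
x = torch.normal(4.0, 0.5, size=(5,))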

14 votes

For a standard normal distribution (i.e. mean=0 and variance=1), you can use torch.randn().

For a custom mean and std, you can use torch.distributions.Normal().


Init signature: tdist.Normal(loc, scale, validate_args=None)

Docstring:
Creates a normal (also called Gaussian) distribution parameterized by loc and scale.

Args:
    loc (float or Tensor): mean of the distribution (often referred to as mu)
    scale (float or Tensor): standard deviation of the distribution (often referred to as sigma)


Here is an example:

In [32]: import torch.distributions as tdist

In [33]: n = tdist.Normal(torch.tensor([4.0]), torch.tensor([0.5]))

In [34]: n.sample((2,))
Out[34]: 
tensor([[ 3.6577],
        [ 4.7001]])
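Beyond sampling, the same Normal object also exposes log_prob and rsample; a short sketch (the commented value follows from the density formula, not a quoted run):

import torch
import torch.distributions as tdist

n = tdist.Normal(torch.tensor([4.0]), torch.tensor([0.5]))
n.log_prob(torch.tensor([4.0]))   # log-density at x = 4 (about -0.2258)
n.rsample((2,))                   # reparameterized sample; gradients flow to loc and scale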

11 votes

A simple option is to use the randn function from the base module. It creates random samples from the standard Gaussian distribution. To change the mean and standard deviation, just use addition and multiplication. Below I create a sample of size 5 from your requested distribution.

import torch
torch.randn(5) * 0.5 + 4 # tensor([4.1029, 4.5351, 2.8797, 3.1883, 4.3868])
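The same shift-and-scale trick works for any shape; a small sketch, with a manual seed added only so the draw is repeatable:

import torch

torch.manual_seed(0)                     # reproducible draw
samples = torch.randn(3, 4) * 0.5 + 4    # shape (3, 4), approximately N(4, 0.5**2)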

2 votes

You can create your distribution as described here in the docs. In your case, this should be the right call, including sampling from the created distribution:

from torch.distributions import normal

m = normal.Normal(4.0, 0.5)
s = m.sample()

If you want to get a sample of a certain size/shape, you can pass it to sample(), e.g.

s = m.sample([5, 5])

for a 5x5 tensor.
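The distribution object also exposes a few handy quantities beyond sampling; a short sketch building on the object above:

import torch
from torch.distributions import normal

m = normal.Normal(4.0, 0.5)
m.mean, m.stddev            # tensor(4.), tensor(0.5000)
m.cdf(torch.tensor(4.0))    # P(X <= 4) = tensor(0.5000)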


0 votes

It depends on what you want to generate.

To generate a standard normal distribution, use -

torch.randn()

For all other distributions (e.g. normal, Poisson, uniform, etc.), use

torch.distributions.Normal()
torch.distributions.Uniform()

Details of all these methods can be found here - https://pytorch.org/docs/stable/distributions.html#normal

Once the distribution is defined, you can use its .sample method to generate as many instances as you need. It also lets you generate a sample of sample_shape shape, or a batch of samples of sample_shape shape if the distribution parameters are batched.
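A minimal sketch of that .sample call, using the Poisson case mentioned above with batched parameters (the rate values are just illustrative):

import torch

d = torch.distributions.Poisson(rate=torch.tensor([1.0, 5.0]))   # two rates -> batch_shape (2,)
x = d.sample((10,))   # sample_shape (10,) + batch_shape (2,) -> tensor of shape (10, 2)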


0 votes

For all distributions, see: https://pytorch.org/docs/stable/distributions.html#

Click Normal in the menu on the right to jump to it (or search the docs).

Sample code:

import torch

num_samples = 3
Din = 1
mu, std = 0, 1
x = torch.distributions.normal.Normal(loc=mu, scale=std).sample((num_samples, Din))

print(x)

For details on torch distributions (with a focus on Uniform), see my answer: https://stackoverflow.com/a/62919760/1601580
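For completeness, a short sketch of the Uniform case with the same API (the bounds are illustrative):

import torch

u = torch.distributions.Uniform(low=-1.0, high=1.0)
y = u.sample((3, 1))   # same shape convention as the Normal example above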


0 votes

import torch
import matplotlib.pyplot as plt



def multivariate_normal_2d(samples):
    mean = torch.zeros(2)
    cov = torch.eye(2)
    dist = torch.distributions.MultivariateNormal(mean, cov)
    return dist.log_prob(samples).exp()

def double_multivariate_normal_2d(samples):
    mean1 = torch.tensor([1.0, 1.0])
    mean2 = torch.tensor([-1.0, -1.0])
    cov = torch.eye(2)
    dist1 = torch.distributions.MultivariateNormal(mean1, cov)
    dist2 = torch.distributions.MultivariateNormal(mean2, cov)
    return 0.5 * (dist1.log_prob(samples).exp() + dist2.log_prob(samples).exp())

class DomainSampler2D:
    def __init__(self, lower_bounds, upper_bounds, prefer_additional_samples=True):
        self.lower_bounds = torch.tensor(lower_bounds)
        self.upper_bounds = torch.tensor(upper_bounds)
        self.dimensions = len(lower_bounds)
        self.prefer_additional_samples = prefer_additional_samples
        self.dimension_ranges = self._calculate_ranges()

    def _calculate_ranges(self):
        if torch.allclose(self.upper_bounds - self.lower_bounds, torch.zeros(self.dimensions)):
            aspect_ratios = torch.zeros(self.dimensions)
            aspect_ratios[0] = 1.0
            return aspect_ratios
        else:
            absolute_ranges = torch.abs(self.upper_bounds - self.lower_bounds)
            return absolute_ranges / torch.sum(absolute_ranges)

    def _calculate_samples_per_dimension(self, total_samples):
        active_dimensions = ~torch.isclose(self.dimension_ranges, torch.zeros(self.dimensions))
        scaling_factor = torch.prod(self.dimension_ranges[active_dimensions])
        relevant_dimensions_count = torch.sum(active_dimensions)
        samples_per_dimension = torch.pow(total_samples / scaling_factor, 1 / relevant_dimensions_count)
        samples_per_dimension = self.dimension_ranges * samples_per_dimension
        return torch.max(samples_per_dimension, torch.ones(samples_per_dimension.shape))

    def _adjust_samples_to_match(self, total_samples, samples_per_dimension):
        integer_samples_per_dimension = torch.round(samples_per_dimension).to(dtype=torch.int)
        current_total = torch.prod(integer_samples_per_dimension)

        if current_total == total_samples:
            return integer_samples_per_dimension

        sample_differences = samples_per_dimension - integer_samples_per_dimension

        if current_total > total_samples and not self.prefer_additional_samples:
            most_over_sampled_dimension = torch.argmin(sample_differences)
            integer_samples_per_dimension[most_over_sampled_dimension] -= 1

        elif current_total < total_samples and self.prefer_additional_samples:
            most_under_sampled_dimension = torch.argmax(sample_differences)
            integer_samples_per_dimension[most_under_sampled_dimension] += 1

        return integer_samples_per_dimension

    def generate_samples(self, total_samples):
        samples_per_dimension = self._calculate_samples_per_dimension(total_samples)
        samples_per_dimension = self._adjust_samples_to_match(total_samples, samples_per_dimension)
        grid_points = [torch.linspace(start, end, count) for start, end, count in zip(self.lower_bounds, self.upper_bounds, samples_per_dimension)]
        mesh = torch.meshgrid(*grid_points, indexing="ij")
        return torch.stack(mesh, dim=-1).reshape(-1, self.dimensions)

def plot_distributions(samples, u_values, v_values):
    samples_numpy = samples.numpy()
    u_numpy = u_values.numpy()
    v_numpy = v_values.numpy()

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    scatter1 = axes[0].scatter(samples_numpy[:, 0], samples_numpy[:, 1], c=u_numpy, cmap='viridis')
    axes[0].set_title('Multivariate Normal Distribution')
    axes[0].set_xlabel('x')
    axes[0].set_ylabel('y')
    fig.colorbar(scatter1, ax=axes[0], label='Density')

    scatter2 = axes[1].scatter(samples_numpy[:, 0], samples_numpy[:, 1], c=v_numpy, cmap='plasma')
    axes[1].set_title('Double Multivariate Normal Distribution')
    axes[1].set_xlabel('x')
    axes[1].set_ylabel('y')
    fig.colorbar(scatter2, ax=axes[1], label='Density')

    plt.tight_layout()
    plt.show()


N = 100
samples = DomainSampler2D([-3.0, -3.0], [3.0, 3.0]).generate_samples(N**2)

plot_distributions(samples, multivariate_normal_2d(samples), double_multivariate_normal_2d(samples))
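If you want random draws from the same two-component mixture rather than its density on a grid, a possible sketch using torch.distributions.MixtureSameFamily (not part of the code above):

import torch

mix = torch.distributions.Categorical(torch.tensor([0.5, 0.5]))    # equal mixture weights
comp = torch.distributions.MultivariateNormal(
    torch.tensor([[1.0, 1.0], [-1.0, -1.0]]),                      # the two means used above
    torch.eye(2).repeat(2, 1, 1))                                   # shared identity covariance
gmm = torch.distributions.MixtureSameFamily(mix, comp)
draws = gmm.sample((1000,))                                         # shape (1000, 2)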