Python 包:Dirichlet 分布的 MLE

问题描述 投票:0回答:2

请问是否有现成的 Python 包,能用最大似然估计(MLE)来估计狄利克雷(Dirichlet)分布的参数?

python statistics mle dirichlet
2个回答
8
投票

Eric Suh 写了一个实现该功能的软件包(https://github.com/ericsuh/dirichlet),可以这样安装:

$ pip install git+https://github.com/ericsuh/dirichlet.git

然后:

import numpy
import dirichlet

# Draw 1000 samples from a Dirichlet with known concentration parameters,
# then ask the `dirichlet` package to recover them by maximum likelihood.
true_alpha = numpy.array([100, 299, 100])
samples = numpy.random.dirichlet(true_alpha, 1000)
dirichlet.mle(samples)

0
投票

虽然回答得有点晚,但万一还能帮到有需要的人……下面是一个兼容 scikit-learn API、用牛顿法求 MLE 的实现:

from sklearn.base import BaseEstimator, RegressorMixin
from scipy.special import polygamma
import pandas as pd
import numpy as np


class DirichletMLE(BaseEstimator, RegressorMixin):
    '''
    Implements a static Dirichlet regressor compliant with the Sklearn API
    providing the MLE estimates through the `fit` method.

    "Static" means the regressor ignores any features: `fit` estimates the
    concentration (shape) parameters of a single Dirichlet distribution from
    a matrix of compositional responses, and `predict` returns the mean of
    that fitted distribution.

    Parameters
    ----------
    `eps: float = 0.00001`
        An optional error tolerance parameter for the numerical solver; it is
        also the tolerance used when checking that each response row sums to
        one. Defaults to `0.00001`.
    `maxiter: int = 100`
        An optional fail-safe parameter limiting the number of iterations the
        numerical solver can perform. Defaults to `100`.

    Attributes
    ----------
    `shape: np.ndarray | None`
        Estimated concentration parameters (set by `fit`).
    `mean_array: np.ndarray | None`
        Mean of the fitted distribution, `shape / shape.sum()`.
    `status: int | None`
        `1` if Newton's method converged within `maxiter` iterations,
        otherwise `0`.
    `niter`, `niter_: int | None`
        Number of Newton iterations performed (two names, same value).
    `ndim`, `ndim_: int | None`
        Number of response components (two names, same value).

    Raises
    ------
    `AssertionError`
        If the error tolerance is non-positive.
    `AssertionError`
        If the maximum number of iterations is non-positive.
    '''

    def __init__(self, eps: float = 0.00001, maxiter: int = 100) -> None:
        assert eps > 0.0, 'Error tolerance must be positive'
        assert maxiter > 0, 'Maximum number of iterations must be positive'

        self.eps = eps
        self.maxiter = maxiter
        # Fitted state, populated by `fit`. Both the plain and the
        # trailing-underscore spellings are kept for backward compatibility,
        # so every name that `__str__`, `predict` or callers read exists.
        self.ndim = None
        self.ndim_ = None
        self.niter = None
        self.niter_ = None
        self.status = None
        self.shape = None
        self.mean_array = None

    def __str__(self) -> str:
        '''Summarise the fitted shape (rounded to 4 dp), iterations and status.'''
        return str({
            'shape': None if self.shape is None else np.round(self.shape, 4),
            'niter': self.niter, 'status': self.status
        })

    def fit(self, y: pd.DataFrame):
        '''
        Calibrates the regressor by finding the maximum likelihood estimates
        by way of Newton's method, seeded with method-of-moments estimates.

        Parameters
        ----------
        `y: pd.DataFrame`
            A response matrix with one observation per row and one component
            per column. Any 2-D array-like accepted by `np.asarray` (e.g. a
            `np.ndarray`) works as well.

        Returns
        -------
        `DirichletMLE`
            The calibrated regressor itself, so calls can be chained.

        Raises
        ------
        `AssertionError`
            If the response is not a non-empty 2-D matrix.
        `AssertionError`
            If one or more response values are not strictly positive (the
            log-likelihood needs `log(y)`, which is `-inf` at zero).
        `AssertionError`
            If one or more response arrays does not sum to 100% within `eps`.
        '''
        # A plain float matrix makes DataFrame and ndarray inputs behave the
        # same; in particular `DataFrame.min()` returns a Series, whose truth
        # value is ambiguous inside an `assert`.
        y = np.asarray(y, dtype=float)
        assert y.ndim == 2 and y.shape[0] > 0, 'The response matrix must be a non-empty 2-D array'  # noqa: E501
        assert np.all(y > 0.0), 'The response matrix is out of bounds'
        # Two-sided check: rows summing to less than one are rejected too.
        assert np.all(np.abs(y.sum(axis=1) - 1.0) <= self.eps), 'The row-wise sum of the response must equal 100%'  # noqa: E501

        ndim = y.shape[1]
        # Sufficient statistic of the Dirichlet log-likelihood.
        logy_mean = np.log(y).mean(axis=0)

        # Method-of-moments seed: for a Dirichlet, sum_i Var(y_i) equals
        # sum_i m_i (1 - m_i) / (a_0 + 1), so `phi` estimates a_0 + 1 and
        # `mean * phi` is a (slightly inflated) starting point for Newton.
        mean = y.mean(axis=0)
        phi = (mean * (1.0 - mean)).sum() / ((y - mean) ** 2.0).sum(axis=1).mean()  # noqa: E501
        a = mean * phi

        niter = 0
        converged = False
        while not converged and niter < self.maxiter:
            niter += 1

            # Gradient and Hessian of the mean log-likelihood w.r.t. `a`
            # (the scalar terms broadcast over every entry).
            g = polygamma(0, a.sum()) + logy_mean - polygamma(0, a)
            h = polygamma(1, a.sum()) - np.diag(polygamma(1, a))

            # Solve h @ delta = -g directly instead of inverting h; a
            # singular or non-finite system ends the search unconverged.
            try:
                delta = np.linalg.solve(h, -g)
            except np.linalg.LinAlgError:
                break
            if not np.all(np.isfinite(delta)):
                break

            # Damp the step until every parameter stays positive, since the
            # digamma/trigamma terms above are only meaningful for a > 0.
            while np.any(a + delta <= 0.0):
                delta = delta / 2.0

            converged = bool(np.linalg.norm(delta) <= self.eps)
            a = a + delta

        self.ndim = self.ndim_ = ndim
        self.niter = self.niter_ = niter
        self.shape = a
        self.mean_array = a / a.sum()
        self.status = int(converged)

        return self

    def predict(self, X: pd.DataFrame | None = None) -> np.ndarray:
        '''
        Predicts the mean response array (decision rule consistent with the use
        of quadratic loss function).

        Parameters
        ----------
        `X: pd.DataFrame | None = None`
            An optional dummy feature matrix to provide compliance with the
            Sklearn API; only its length is used.

        Returns
        -------
        `np.ndarray`
            The fitted mean as a 1-D array when `X` is `None`, otherwise a
            `(len(X), ndim)` matrix repeating it once per row of `X`.

        Raises
        ------
        `AssertionError`
            If the regressor has not been successfully calibrated.
        '''
        assert self.status == 1, 'Regressor must be calibrated prior to predicting'  # noqa: E501

        if X is not None:
            return np.tile(self.mean_array, (len(X), 1))

        # Return a copy so callers cannot mutate the fitted state in place.
        return self.mean_array.copy()


if __name__ == '__main__':
    # Demo: sample from a Dirichlet with known concentration parameters and
    # check that the fitted `shape` recovers them. A seeded generator makes
    # the run reproducible.
    rng = np.random.default_rng(seed=0)
    size = 10_000
    ndim = 10
    # Keep concentrations away from zero: very small alphas make sampled
    # components underflow to exactly 0.0, and log(0) breaks the fit.
    shape = rng.uniform(low=0.5, high=10.0, size=ndim)
    y = rng.dirichlet(alpha=shape, size=size)
    reg = DirichletMLE().fit(y=y)

    print({'shape': shape.round(4)})
    print(reg)
© www.soinside.com 2019 - 2024. All rights reserved.