Python 包：Dirichlet 分布的最大似然估计（MLE）

问题描述 投票:0回答:2


python statistics mle dirichlet

Eric Suh 编写了一个现成的软件包 `dirichlet`，可以直接从其 Git 仓库安装：

$ pip install git+https://github.com/ericsuh/dirichlet.git


import numpy
import dirichlet
# Ground-truth concentration vector and 1000 synthetic draws from the
# corresponding Dirichlet distribution (one draw per row, 3 components).
a0 = numpy.asarray((100, 299, 100))
D0 = numpy.random.dirichlet(alpha=a0, size=1000)



from sklearn.base import BaseEstimator, RegressorMixin
from scipy.special import polygamma
import pandas as pd
import numpy as np

class DirichletMLE(BaseEstimator, RegressorMixin):
    """
    Static (feature-free) Dirichlet regressor compliant with the Sklearn API.

    ``fit`` computes the maximum likelihood estimate of the Dirichlet shape
    (concentration) vector with Newton's method, seeded by a method-of-moments
    estimate. ``predict`` returns the mean of the fitted distribution.

    Parameters
    ----------
    eps : float, default 0.00001
        Error tolerance for the numerical solver: iteration stops once the
        Euclidean norm of a Newton step is at most ``eps``. Also used as the
        tolerance when checking that each response row sums to one.
    maxiter : int, default 100
        Fail-safe upper bound on the number of Newton iterations.

    Attributes
    ----------
    shape : np.ndarray or None
        Estimated shape vector (set by ``fit``).
    mean_array : np.ndarray or None
        Mean of the fitted distribution, ``shape / shape.sum()``.
    ndim : int or None
        Number of response components seen by ``fit`` (mirrored in ``ndim_``).
    niter : int or None
        Number of Newton iterations performed (mirrored in ``niter_``).
    status : int or None
        ``1`` if the solver converged within ``maxiter`` iterations, else ``0``.

    Raises
    ------
    AssertionError
        If the error tolerance is non-positive.
        If the maximum number of iterations is non-positive.
    """

    def __init__(self, eps: float = 0.00001, maxiter: int = 100) -> None:
        self._require(eps > 0.0, 'Error tolerance must be positive')
        self._require(
            maxiter > 0, 'Maximum number of iterations must be positive'
        )

        self.eps = eps
        self.maxiter = maxiter
        # Fitted state, populated by ``fit``. Both the plain and the
        # trailing-underscore spellings are kept so either name works.
        self.ndim = None
        self.ndim_ = None
        self.niter = None
        self.niter_ = None
        self.status = None
        self.shape = None
        self.mean_array = None

    @staticmethod
    def _require(condition: bool, message: str) -> None:
        """Raise ``AssertionError(message)`` unless ``condition`` holds."""
        # Raised explicitly rather than via ``assert`` so the validation
        # survives ``python -O`` while keeping the same exception type.
        if not condition:
            raise AssertionError(message)

    def __str__(self) -> str:
        return str({
            'shape': None if self.shape is None else np.round(self.shape, 4),
            'niter': self.niter, 'status': self.status,
        })

    def fit(self, y: pd.DataFrame) -> 'DirichletMLE':
        """
        Calibrate the regressor by finding the maximum likelihood estimate of
        the shape vector with Newton's method.

        Parameters
        ----------
        y : pd.DataFrame or array-like of shape (n_samples, n_components)
            Response matrix; each row is one observation on the simplex.

        Returns
        -------
        DirichletMLE
            The fitted estimator (``self``).

        Raises
        ------
        AssertionError
            If the response is not a 2-D matrix with at least two rows.
            If one or more response values are not strictly positive (the
            likelihood involves ``log(y)``).
            If one or more response rows does not sum to 100% within ``eps``.
            If every row is identical (the MLE does not exist).
        """
        # Work on a plain float array so the checks below yield scalars
        # (comparisons on a DataFrame yield an ambiguous-truth Series).
        y = np.asarray(y, dtype=float)
        self._require(
            y.ndim == 2 and y.shape[0] >= 2,
            'The response must be a 2-D matrix with at least two rows',
        )
        self._require(
            bool((y > 0.0).all()), 'The response matrix is out of bounds'
        )
        self._require(
            bool((np.abs(y.sum(axis=1) - 1.0) <= self.eps).all()),
            'The row-wise sum of the response must equal 100%',
        )

        n_components = y.shape[1]
        # Only sufficient statistic of the Dirichlet log-likelihood.
        logy_mean = np.log(y).mean(axis=0)

        # Method-of-moments seed. Since Var(y_k) = m_k (1 - m_k) / (phi + 1),
        # this ratio estimates phi + 1 rather than phi; the slight upward bias
        # only moves the starting point, not the Newton fixed point.
        mean = y.mean(axis=0)
        spread = ((y - mean) ** 2.0).sum(axis=1).mean()
        self._require(spread > 0.0, 'The response matrix must not be constant')
        a = mean * ((mean * (1.0 - mean)).sum() / spread)

        niter = 0
        check = False
        # ``<`` (not ``<=``) so that at most ``maxiter`` iterations run.
        while not check and niter < self.maxiter:
            niter += 1

            # Gradient and Hessian of the mean log-likelihood with respect to
            # ``a``; the scalar trigamma term broadcasts over the whole matrix.
            g = polygamma(n=0, x=a.sum()) + logy_mean - polygamma(n=0, x=a)
            h = polygamma(n=1, x=a.sum()) - np.diag(polygamma(n=1, x=a))

            try:
                delta = np.linalg.solve(h, -g)
            except np.linalg.LinAlgError:
                break  # singular Hessian: stop without claiming convergence

            candidate = a + delta
            if not np.all(np.isfinite(candidate) & (candidate > 0.0)):
                break  # step left the parameter space: stop, unconverged
            a = candidate
            check = bool(np.linalg.norm(delta) <= self.eps)

        self.ndim = self.ndim_ = n_components
        self.shape = a
        self.mean_array = self.shape / self.shape.sum()
        self.niter = self.niter_ = niter
        self.status = int(check)

        return self

    def predict(self, X: pd.DataFrame | None = None) -> np.ndarray:
        """
        Predict the mean response array (the decision rule consistent with a
        quadratic loss function).

        Parameters
        ----------
        X : pd.DataFrame or None, default None
            Optional dummy feature matrix providing compliance with the
            Sklearn API; only its length is used.

        Returns
        -------
        np.ndarray
            ``mean_array`` (shape ``(ndim,)``) when ``X`` is None, otherwise
            that array repeated once per row of ``X`` (shape
            ``(len(X), ndim)``).

        Raises
        ------
        AssertionError
            If the regressor has not been successfully calibrated.
        """
        self._require(
            self.status == 1,
            'Regressor must be calibrated prior to predicting',
        )

        if X is not None:
            return np.tile(self.mean_array, (len(X), 1))

        return self.mean_array

if __name__ == '__main__':
    # Demo: sample from a Dirichlet with a random but known shape vector,
    # fit the estimator, and print the true shape next to the estimate.
    rng = np.random.default_rng(seed=0)  # seeded so runs are reproducible
    size = 10_000
    ndim = 10
    # Keep concentrations away from zero: very small shape values make draws
    # underflow to exactly 0, where log(y) (and the likelihood) is undefined.
    shape = rng.uniform(low=0.5, high=10.0, size=ndim)
    y = rng.dirichlet(alpha=shape, size=size)
    reg = DirichletMLE().fit(y=y)

    print({'shape': shape.round(4)})
    print(reg)
© 2019 - 2024. All rights reserved.