如何使用pgmpy获得贝叶斯网络中新事件的概率?

问题描述 投票:0回答:2

我使用 pgmpy 库训练了贝叶斯网络。我希望找到一个新事件的联合概率(作为给定其父事件的每个变量的概率的乘积,如果有的话)。

目前我正在做

infer = VariableElimination(model)
evidence = dict(x_test.iloc[0])
result = infer.query(variables=[], evidence=evidence, joint=True)
print(result)

这里

x_test
是测试数据框。

result
是非常大的输出,包含训练数据及其概率的所有组合。

+----------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------+------------------------------------------+-----------------+---------------------------+-----------------------------------------+------------------------------+------------------------+---------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+
| data_devicetype                                                                                                                              | data_username                      | data_applicationtype                     | event_type      | servicename               | data_applicationname                    | tenantname                   | data_origin            | geoip_country_name        |   phi(data_devicetype,data_username,data_applicationtype,event_type,servicename,data_applicationname,tenantname,data_origin,geoip_country_name) |
+==============================================================================================================================================+====================================+==========================================+=================+===========================+=========================================+==============================+========================+===========================+=================================================================================================================================================+
| data_devicetype(Mozilla_5_0_Windows_NT_10_0_Win64_x64_AppleWebKit_537_36_KHTML_like_Gecko_Chrome_94_0_4606_81_Safari_537_36)                 | data_username(christofer) | data_applicationtype(Custom_Application) | event_type(sso) | servicename(saml_runtime) | data_applicationname(GD)            | tenantname(amx-sni-ksll0) | data_origin(1_0_64_66) | geoip_country_name(Japan) |                                                                                                                                          0.0326 |
+----------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------+------------------------------------------+-----------------+---------------------------+-----------------------------------------+------------------------------+------------------------+---------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+
| data_devicetype(Mozilla_5_0_Windows_NT_10_0_Win64_x64_AppleWebKit_537_36_KHTML_like_Gecko_Chrome_94_0_4606_81_Safari_537_36)                 | data_username(marty) | data_applicationtype(Custom_Application) | event_type(sso) | servicename(saml_runtime) | data_applicationname(VAULT)      | tenantname(login_pqr_com) | data_origin(1_0_64_66) | geoip_country_name(Japan) |                                                                                                                                          0.0156 |
+----------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------+------------------------------------------+-----------------+---------------------------+-----------------------------------------+------------------------------+------------------------+---------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------+
| data_devicetype(Mozilla_5_0_Windows_NT_10_0_Win64_x64_AppleWebKit_537_36_KHTML_like_Gecko_Chrome_94_0_4606_81_Safari_537_36)                 | data_username(lincon) | data_applicationtype(Custom_Application) | event_type(sso) | servicename(saml_runtime) | data_applicationname(apps_think4ch_com) | tenantname(login_abc_com) | data_origin(1_0_64_66) | geoip_country_name(Japan) |                                                                                                                                          0.0113 |
......contd

请帮助我如何找出新事件(即测试数据中的一行)的概率。概率表达式为

P(data_devicetype, data_username, data_applicationtype, event_type, servicename, data_applicationname, tenantname, data_origin, geoip_country_name)

python-3.x bayesian-networks pgmpy
2个回答
0
投票

如果我理解正确,那么您正在尝试计算新数据点的概率。不幸的是,在 pgmpy 中还没有直接的方法来做到这一点。虽然可以从推理结果中得到概率值。 像这样的东西:

infer = VariableElimination(model)
result = infer.query(variables=list(model.nodes()), joint=True)
evidence = dict(x_test.iloc[0])
p_evidence = result.get_value(**evidence)

本质上,这里我们计算所有变量的联合分布,然后获取

evidence
数据点的概率值。 正如您所期望的,在大型网络的情况下,这在计算上可能非常低效。在这种情况下,计算概率的近似方法是使用模拟。

nsamples = int(1e6)
samples = model.simulate(nsamples)
evidence = dict(x_test.iloc[0])
matching_samples = samples[np.logical_and.reduce([samples[k]==v for k, v in evidence.items()])]
p_evidence = matching_samples.shape[0] / nsamples

通过模拟方法,我们从模型中生成一些模拟数据,并检查其中有多少样本与我们的数据点匹配,这就是它的概率。


0
投票

我可能有点晚了,但如果它对某人有帮助,这里有一个关于如何将 pgmpy 查询的结果转换为 pandas df 以便进一步处理的解决方案:

from pgmpy.utils import get_example_model
from pgmpy.factors.discrete.DiscreteFactor import DiscreteFactor
import numpy as np
import itertools as it

def to_pandas(self) -> pd.DataFrame:
    """ Convert a DiscreteFactor to a pandas DataFrame."""
    states: dict = self.state_names
    values: np.ndarray = self.values

    state_combinations = it.product(*states.values())

    df = pd.DataFrame(state_combinations, columns=states.keys())
    df['prob'] = values.flatten()

    return df


# register function as a method to the DiscreteFactor class
DiscreteFactor.to_pandas = to_pandas


# Example usage
model = get_example_model(model="asia")
inf = VariableElimination(model)

result = inf.query(variables=["tub", "smoke"])
result_df = result.to_pandas()
print(result_df)
© www.soinside.com 2019 - 2024. All rights reserved.