极大似然法拟合概率分布

极大似然拟合概率分布

对数据集应用概率模型是解释数据集的一个好方法,但是,如何找到一个合适的模型本身就是一项工作。在选定模型之后,还要将其与数据进行比较或者检验。在这个例子当中,我们针对statsmodels自带的数据集“心脏移植后存活时间(1967-1974)”,采用极大似然估计的方法拟合概率分布。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
import pandas as pd
import scipy.stats as st
import statsmodels.datasets as datasets

from plotnine.ggplot import *
from plotnine.qplot import *
from plotnine.geoms import *
from plotnine.coords import *
from plotnine.labels import *
from plotnine.facets import *
from plotnine.scales import *
from plotnine.themes import *


import plotly.graph_objs as go
import plotly.offline as py_offline
import plotly.plotly as py
from plotly import tools

py_offline.init_notebook_mode()
%matplotlib inline
1
data = datasets.heart.load_pandas().data
1
2
data = data[data.censors==1]
qplot(x=data.survival.index, y=sorted(data.survival)[::-1], xlab='Patient', ylab='Survival time')

output_3_0.png

<ggplot: (158579004273)>
1
qplot(x=data.survival,bins=12,geom='histogram', xlab='Patient', ylab='Survival time')

output_4_0.png

<ggplot: (-9223371878275771602)>
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
trace1 = go.Scatter(
x=data.index,
y=sorted(data.survival)[::-1],
mode='markers',
name='散点图'
)

trace2 = go.Histogram(x=data.survival, name='直方图')

fig = tools.make_subplots(rows=1, cols=2)
fig.append_trace(trace1, 1, 1)

fig.append_trace(trace2, 1, 2)
fig['layout']['xaxis1'].update(title='病人')
fig['layout']['xaxis2'].update(title='存活时间(天)')

fig['layout']['yaxis1'].update(title='存活时间(天)')
fig['layout']['yaxis2'].update(title='病人数目')


fig['layout'].update(height=600, width=1000, title='病人存货情况')
py_offline.iplot(fig)
This is the format of your plot grid:
[ (1,1) x1,y1 ]  [ (1,2) x2,y2 ]

通过原始数据的散点图和直方图,我们可以看到,绝大部分的存活时间在心脏移植后不超过3年,当然这是上个世纪六七十年代的数据,时至今日,今天的存活率和存活时间已经打大提高。

从直方图上可以看出,存活时间的频数随着天数的增加快速下降,因此考虑采用指数分布来拟合数据。

一个指数分布的概率密度函数是:

f(x;λ)={λeλx,x0,0,x<0.f(x;\lambda)= \begin{cases} \lambda e^{-\lambda x}, &x\geq0,\\ 0, &x<0. \end{cases}

其中λ>0\lambda>0是分布的一个参数,常被称为率参数(rate parameter),λ\lambda的倒数被称为scale参数。

假设存活天数的数目为ssss是服从参数为λ\lambda的随机变量,那么根据极大似然法容易得到,λ\lambda的极大似然估计为1sˉ\frac{1}{\bar{s}},即ss的样本均值。

1
2
3
4
5
6
survival_mean = data.survival.mean()
rate = 1. / survival_mean
smax = data.survival.max()
days = np.linspace(0., smax, 1000)
# bin size: interval between two
# consecutive values in `days` dt = smax / 999.

这样,我们就得到了拟合分布。

1
dist_exp = st.expon.pdf(days, scale=survival_mean)

然后将拟合分布的概率密度函数与原始数据的直方图进行比较,需要注意的是,一个是概率,一个是频数,所以需要转换为同样的标准。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
nbins = 30
trace0 = go.Histogram(
x=data.survival,
nbinsx=nbins,
name = 'Emrpical data',
)

trace1 = go.Scatter(
x=days,
y=dist_exp*len(data.index)*smax/nbins,
name='Fitted PDF',
mode='lines',
)

data_plot=[trace0, trace1]
# Edit the layout
layout = dict(xaxis = dict(title = 'Survival time (days)'),
yaxis = dict(title = 'Number of patients'),
)

fig = go.Figure(data=data_plot, layout=layout)
py_offline.iplot(fig)

从上图可以看出,直接用指数分布拟合的结果并不是很理想,这可能是由于极大似然估计本身的缺陷(在假设情况下只用到了样本的均值),此外,对于指数分布MLE能够得到解析解,但是有些分布可能得不到解析解,就需要通过EM算法等近似求解数值解。
那么,我们通过scipy数值求解指数分布的参数:

1
2
3
dist = st.expon
args = dist.fit(data.survival)
args
(1.0, 222.2888888888889)

正如我们在极值分析里所做的一样,我们可以用K-S检验衡量分布对数据的拟合优度。

1
st.kstest(data.survival, dist.cdf, args)
KstestResult(statistic=0.36199693810792966, pvalue=8.647045785181717e-06)

如此小的p值,意味着拒绝原假设(样本的分布与拟合分布相同),即两者存在这显著的差异,说明不应该指数分布拟合。

改用另外一种分布Brinbaum-Sanders distribution,这种分布通常用来拟合疲劳寿命或者失效时间。

1
2
3
dist = st.fatiguelife
args = dist.fit(data.survival)
st.kstest(data.survival, dist.cdf, args)
KstestResult(statistic=0.1877344610194689, pvalue=0.07321149700086327)

由上可知,p值为0.073,在5%的置信水平上无法拒绝原假设,说明BS分布比指数分布更加合适。另一方面需要注意的是,在scipy中,并不是直接采用原始数据进行拟合的,而是利用loc和scale两个参数进行规范化后,与标准分布进行拟合,这一点在实际应用中需要注意。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
dist_bs = st.fatiguelife.pdf(days, *args)
trace0 = go.Histogram(
x=data.survival,
nbinsx=nbins,
name = 'Emrpical data',
)

trace1 = go.Scatter(
x=days,
y=dist_exp*len(data.index)*smax/nbins,
name='Fitted by EXP',
mode='lines',
)


trace2 = go.Scatter(
x=days,
y=dist_bs*len(data.index)*smax/nbins,
name='Fitted by BS',
mode='lines',
)


data_plot=[trace0, trace1, trace2]
# Edit the layout
layout = dict(xaxis = dict(title = 'Survival time (days)'),
yaxis = dict(title = 'Number of patients'),
)

fig = go.Figure(data=data_plot, layout=layout)
py_offline.iplot(fig)

张da统帅 wechat
扫码订阅我的微信公众号