我试图根据潜在状态S = 0,1来估计具有切换平均值的AR(1)过程y,该过程y以固定转移概率的马尔可夫过程(如here)的形式发展.简而言之,它采取以下形式:

y_t - mu_{0/1} = phi * (y_{t-1} - mu_{0/1})+ epsilon_t

其中,如果STATE_T=0,则使用mu_0,如果STATE_T=1,则使用MU_1. 我正在使用带有DiscreteHMCGibbs的JAX/Numpyro(尽管带有潜伏态枚举的普通坚果会产生相同的结果),但我似乎无法使采样器正常工作.从我运行的所有诊断来看,似乎所有的超参数都停留在初始化值上,摘要相应地返回所有std==0.下面我有一个重现我的问题的MWE.我在实现过程中犯了什么明显的错误吗?

MWe:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS,DiscreteHMCGibbs
from jax import random, pure_callback
import jax
import numpy as np

def generate_synthetic_data(T=100, mu=[0, 5], phi=0.5, sigma=1.0, p=np.array([[0.95, 0.05], [0.1, 0.9]])):
    states = np.zeros(T, dtype=np.int32)
    y = np.zeros(T)
    current_state = np.random.choice([0, 1], p=[0.5, 0.5])
    states[0] = current_state
    y[0] = np.random.normal(mu[current_state], sigma)

    for t in range(1, T):
        current_state = np.random.choice([0, 1], p=p[current_state,:])
        states[t] = current_state
        y[t] = np.random.normal(mu[current_state] + phi * (y[t-1] - mu[current_state]), sigma)

    return y, states


def mean_switching_AR1_model(y):
    T = len(y)
    phi = numpyro.sample('phi', dist.Normal(0, 1))
    sigma = numpyro.sample('sigma', dist.Exponential(1))
    
    
    with numpyro.plate('state_plate', 2):
        mu = numpyro.sample('mu', dist.Normal(0, 5))
        p = numpyro.sample('p', dist.Dirichlet(jnp.ones(2)))

    probs_init = numpyro.sample('probs_init', dist.Dirichlet(jnp.ones(2)))
    s_0 = numpyro.sample('s_0', dist.Categorical(probs_init))

    def transition_fn(carry, y_t):
        prev_state = carry
        state_probs = p[prev_state]
        state = numpyro.sample('state', dist.Categorical(state_probs))

        mu_state = mu[state]
        y_mean = mu_state + phi * (y_t - mu_state)
        y_next = numpyro.sample('y_next', dist.Normal(y_mean, sigma), obs=y_t)
        return state, (state, y_next)

    _ , (signal, y)=scan(transition_fn, s_0, y[:-1], length=T-1)
    return (signal, y)

# Synthetic data generation
T = 1000
mu_true = [0, 3]
phi_true = 0.5
sigma_true = 0.25
transition_matrix_true = np.array([[0.95, 0.05], [0.1, 0.9]])
y, states_true = generate_synthetic_data(T, mu=mu_true, phi=phi_true, sigma=sigma_true, p=transition_matrix_true)


rng_key = random.PRNGKey(0)
nuts_kernel = NUTS(mean_switching_AR1_model)
gibbs_kernel = DiscreteHMCGibbs(nuts_kernel, modified=True)

# Run MCMC
mcmc = MCMC(gibbs_kernel, num_samples=1000, num_warmup=1000)
mcmc.run(rng_key, y=y)
mcmc.print_summary()

推荐答案

所以事实证明,这确实存在一个很明显的错误,我没有正确地把y_{t-1}作为状态变量的一部分.以下修正后的转换函数可以毫无问题地产生预期结果.

def transition_fn(carry, y_curr):
    prev_state, y_prev = carry
    state_probs = p[prev_state]
    state = numpyro.sample('state', dist.Categorical(state_probs))

    mu_state = mu[state]
    y_mean = mu_state + phi * (y_prev - mu_state)
    y_curr = numpyro.sample('y_curr', dist.Normal(y_mean, sigma), obs=y_curr)
    return (state, y_curr), (state, y_curr)

_, (signal, y) = scan(transition_fn, (s_0, y[0]), y[1:], length=T-1)

Python相关问答推荐

使用Keras的线性回归参数估计

在内部列表上滚动窗口

'discord.ext. commanders.cog没有属性监听器'

海运图:调整行和列标签

Pandas计数符合某些条件的特定列的数量

给定高度约束的旋转角解析求解

如何使用scipy的curve_fit与约束,其中拟合的曲线总是在观测值之下?

调用decorator返回原始函数的输出

在两极中过滤

python panda ExcelWriter切换动态公式到数组公式

如何按row_id/row_number过滤数据帧

如何删除重复的文字翻拍?

判断Python操作:如何从字面上得到所有decorator ?

高效生成累积式三角矩阵

如何根据一定条件生成段id

Regex用于匹配Python中逗号分隔的AWS区域

将像素信息写入文件并读取该文件

是否将Pandas 数据帧标题/标题以纯文本格式转换为字符串输出?

利用广播使减法更有效率

更新包含整数范围的列表中的第一个元素