# Source: grammar-inference-engine/bex/baum_welch.py
# (repository viewer metadata: 193 lines, 5.8 KiB, Python)

"""Baum-Welch for POMM on k-OA — standard forward-backward (Rabiner 1989)."""
import random
import math
def init_probabilities(G, sequences):
    """Build the initial transition probabilities (iKoa init, Algorithm 1, line 1).

    α(src, sink) = fraction of empty words in ``sequences``
    α(src, s)    = fraction of words starting with the base symbol of lab(s),
                   split equally among all copies of that symbol reachable
                   from src
    α(s, t) for s ≠ src: drawn with ``random.random()`` and normalised so the
                   row sums to 1

    Args:
        G: graph object exposing ``_succ`` (node -> iterable of successors),
            ``src``, ``sink`` and ``label(node)``.
        sequences: list of token sequences (the sample).

    Returns:
        dict mapping each non-sink node ``s`` to a dict ``{t: probability}``
        over its successors (an empty dict when ``s`` has no successors).
        Entries below 1e-10 are clamped to exactly 0.0.
    """
    def _base(label):
        # A label may carry a "_<copy>" suffix; the base symbol is the part
        # before the last underscore. Falsy labels have no base symbol.
        if not label:
            return None
        return label.rsplit('_', 1)[0] if '_' in label else label

    # Avoid division by zero on an empty sample.
    total = len(sequences) or 1
    empty_count = sum(1 for seq in sequences if not seq)

    # How many words start with each symbol.
    start_counts = {}
    for seq in sequences:
        if seq:
            start_counts[seq[0]] = start_counts.get(seq[0], 0) + 1

    src, sink = G.src, G.sink
    prob = {}
    for s in G._succ:
        if s == sink:
            continue
        succ = list(G._succ[s])
        if not succ:
            prob[s] = {}
            continue

        if s == src:
            # Count copies per *base* symbol. Comparing full labels would see
            # "a_1" and "a_2" as different symbols and never split the mass.
            bases = [None if t == sink else _base(G.label(t)) for t in succ]
            copies = {}
            for t, base in zip(succ, bases):
                if t != sink:
                    copies[base] = copies.get(base, 0) + 1
            vals = []
            for t, base in zip(succ, bases):
                if t == sink:
                    vals.append(empty_count / total)
                else:
                    count = start_counts.get(base, 0)
                    vals.append((count / total) / copies[base])
        else:
            vals = [random.random() for _ in succ]

        # Normalise the row; fall back to uniform when it carries no mass.
        s_total = sum(vals)
        if s_total == 0:
            vals = [1.0 / len(vals)] * len(vals)
        else:
            vals = [v / s_total for v in vals]
        prob[s] = dict(zip(succ, vals))

    # Clamp numerically negligible entries to an exact zero.
    for row in prob.values():
        for t in row:
            if row[t] < 1e-10:
                row[t] = 0.0
    return prob
def _emitters_by_symbol(G, all_nodes):
    """Map each base symbol to the list of nodes whose label carries it."""
    emit = {}
    for n in all_nodes:
        lab = G.label(n)
        if lab:
            # Labels may carry a "_<copy>" suffix; key by the part before it.
            base = lab.rsplit('_', 1)[0] if '_' in lab else lab
            emit.setdefault(base, []).append(n)
    return emit


def bw_iteration(prob, sequences, node_to_idx, n_states, all_nodes, G):
    """Run a single Baum-Welch (EM) iteration over all sequences.

    A sequence is generated by a path src -> s_1 -> ... -> s_T -> sink where
    each s_t carries the t-th observed symbol. Expected transition counts are
    accumulated for every step of such paths, *including* the final step into
    the sink and the direct src -> sink step taken by an empty sequence, and
    the rows of ``prob`` are then re-estimated from those counts.

    Args:
        prob: dict ``prob[s][t]`` of transition probabilities; updated in
            place.
        sequences: list of token sequences.
        node_to_idx: unused; kept so existing callers keep working.
        n_states: unused; kept so existing callers keep working.
        all_nodes: list of graph nodes to consider as states.
        G: graph object exposing ``src``, ``sink`` and ``label(node)``.

    Returns:
        The same ``prob`` dict, after the update. Rows that received no
        expected counts are reset to a uniform distribution over their
        successors.
    """
    src = G.src
    sink = G.sink
    # The emit map depends only on the graph, so build it once per iteration.
    emit = _emitters_by_symbol(G, all_nodes)

    total_num = {}
    total_denom = {}

    def _accumulate(i, j, xi):
        # Ignore numerically negligible contributions.
        if xi > 1e-15:
            key = (i, j)
            total_num[key] = total_num.get(key, 0.0) + xi
            total_denom[i] = total_denom.get(i, 0.0) + xi

    for seq in sequences:
        obs = seq if seq else ()
        T = len(obs)
        # Candidate states for each position (same lookup in every pass).
        candidates = [emit.get(sym, ()) for sym in obs]

        # Forward pass: alpha[t][j] = P(o_1..o_t, state after t steps = j).
        alpha = [{} for _ in range(T + 1)]
        alpha[0][src] = 1.0
        for t in range(T):
            prev = alpha[t]
            for j in candidates[t]:
                total = 0.0
                for i, a in prev.items():
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans > 0:
                        total += a * p_trans
                if total > 0:
                    alpha[t + 1][j] = total

        # P(O | λ): the path must finish with a transition into the sink.
        po = 0.0
        for i, a in alpha[T].items():
            po += a * prob.get(i, {}).get(sink, 0.0)
        if po == 0:
            # The current model cannot generate this sequence.
            continue

        # Backward pass: beta[T][i] is the probability of stopping from i.
        beta = [{} for _ in range(T + 1)]
        for i in all_nodes:
            p_end = prob.get(i, {}).get(sink, 0.0)
            if p_end > 0:
                beta[T][i] = p_end
        for t in range(T - 1, -1, -1):
            nxt = beta[t + 1]
            for i in alpha[t]:
                total = 0.0
                for j in candidates[t]:
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans > 0 and j in nxt:
                        total += p_trans * nxt[j]
                if total > 0:
                    beta[t][i] = total

        # Expected counts ξ for the emitting transitions.
        for t in range(T):
            nxt = beta[t + 1]
            for i, a in alpha[t].items():
                if i not in beta[t] or beta[t][i] == 0:
                    continue
                for j in candidates[t]:
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans == 0 or j not in nxt or nxt[j] == 0:
                        continue
                    _accumulate(i, j, a * p_trans * nxt[j] / po)

        # Expected counts for the closing transition into the sink. Without
        # these the M-step would zero every prob[s][sink] that competes with
        # another outgoing transition.
        for i, a in alpha[T].items():
            p_end = prob.get(i, {}).get(sink, 0.0)
            if p_end > 0:
                _accumulate(i, sink, a * p_end / po)

    # M-step: update probabilities from the expected counts.
    for s in prob:
        d = total_denom.get(s, 0.0)
        for t in prob[s]:
            key = (s, t)
            if d > 1e-15 and key in total_num:
                prob[s][t] = total_num[key] / d
            else:
                prob[s][t] = 0.0

    # Renormalize; rows without any mass fall back to uniform.
    for s in prob:
        row_sum = sum(prob[s].values())
        if row_sum > 1e-10:
            for t in prob[s]:
                prob[s][t] /= row_sum
        else:
            n_succ = len(prob[s])
            for t in prob[s]:
                prob[s][t] = 1.0 / n_succ
    return prob
def baum_welch(G, prob, sequences, iterations=10):
    """Train transition probabilities with repeated Baum-Welch passes.

    Args:
        G: k-OA graph; its ``_succ`` keys are taken as the state set.
        prob: dict ``prob[s][t]`` of transition probabilities.
        sequences: list of token lists (a bag, so duplicates count).
        iterations: number of EM passes to run. The count is fixed; no
            convergence test is performed here.

    Returns:
        The probability dict returned by the last pass (``prob`` itself when
        ``iterations`` is 0).
    """
    states = [*G._succ.keys()]
    state_count = len(states)
    # Node -> position lookup, handed through to every pass.
    index_of = dict(zip(states, range(state_count)))

    current = prob
    for _ in range(iterations):
        current = bw_iteration(
            current, sequences, index_of, state_count, states, G
        )
    return current
def baum_welch_fixed(G, prob, sequences, iterations=2):
    """Run Baum-Welch for a small, fixed number of iterations (for Disambiguate).

    The iteration count is chosen by the caller. The original note reads
    "2 for |Σ| ≤ 7, 3 for |Σ| > 7" (the first relation was lost in the source
    text and is presumed to be ≤ — confirm against the algorithm description).

    Returns:
        Whatever ``baum_welch`` returns for the same arguments.
    """
    trained = baum_welch(G, prob, sequences, iterations)
    return trained