- CRX: direct CHARE inference (Algorithm 7, TODS 2010) - iDRegEx: k-ORE inference (Algorithm 4, arXiv 2010) - RWR₀: SORE repair (Algorithm 6, TODS 2010) - rwr²: k-ORE extraction (Algorithm 3, arXiv 2010) - SOA, k-OA, iKoa, 2T-INF, Baum-Welch - Ansible role grammar adapter - Generic YAML key-path converter - 28 tests, all passing
192 lines
5.8 KiB
Python
192 lines
5.8 KiB
Python
"""Baum-Welch for POMM on k-OA — standard forward-backward (Rabiner 1989)."""
|
||
|
||
import random
|
||
import math
|
||
|
||
|
||
def init_probabilities(G, sequences):
    """Build the initial transition probabilities per iKoa init (Algorithm 1, line 1).

    - α(src, sink) = fraction of empty words in ``sequences``
    - α(src, s)    = fraction of words starting with the base symbol of
      lab(s), split equally among all copies of that symbol reachable
      from src
    - α(s, t) for s ≠ src: drawn from ``random.random()`` and normalized so
      that each row sums to 1

    A node's base symbol is its label with everything after the last ``_``
    removed (``"a_2"`` -> ``"a"``); labels without ``_`` are used as is.

    Args:
        G: graph object exposing ``_succ`` (node -> iterable of successors),
            ``src``, ``sink`` and ``label(node)``.
        sequences: collection of token sequences; falsy entries count as
            empty words.

    Returns:
        dict mapping every non-sink node in ``G._succ`` to a dict
        ``{successor: probability}``. Nodes without successors map to ``{}``.
        Probabilities below 1e-10 are clamped to exactly 0.0.
    """
    # Avoid division by zero when no sequences are given.
    total = len(sequences) or 1
    empty_count = sum(1 for s in sequences if not s)

    # How many words start with each symbol.
    start_counts = {}
    for seq in sequences:
        if seq:
            start_counts[seq[0]] = start_counts.get(seq[0], 0) + 1

    prob = {}
    for s in G._succ:
        if s == G.sink:
            continue
        succ = list(G._succ[s])
        if not succ:
            prob[s] = {}
            continue

        if s == G.src:
            # Group the emitting successors of src by base symbol so the
            # start mass of a symbol is shared by ALL of its copies, not
            # only by nodes whose full label happens to be identical.
            base_of = {}
            copies = {}
            for t in succ:
                if t == G.sink:
                    continue
                lab = G.label(t)
                base = lab.rsplit('_', 1)[0] if '_' in lab else lab
                base_of[t] = base
                copies[base] = copies.get(base, 0) + 1

            vals = []
            for t in succ:
                if t == G.sink:
                    vals.append(empty_count / total)
                else:
                    base = base_of[t]
                    share = start_counts.get(base, 0) / total
                    vals.append(share / copies[base])
        else:
            # Interior rows start from random weights.
            vals = [random.random() for _ in succ]

        s_total = sum(vals)
        if s_total == 0:
            # No evidence at all for this row: fall back to uniform.
            vals = [1.0 / len(vals)] * len(vals)
        else:
            vals = [v / s_total for v in vals]
        prob[s] = dict(zip(succ, vals))

    # Treat numerically negligible transitions as impossible.
    for row in prob.values():
        for t, p in row.items():
            if p < 1e-10:
                row[t] = 0.0

    return prob
|
||
|
||
|
||
def bw_iteration(prob, sequences, node_to_idx, n_states, all_nodes, G):
    """Run one Baum-Welch (EM) re-estimation pass over all sequences.

    E-step: for every sequence a forward/backward pass over the graph
    yields the expected number of times each transition is taken. This
    includes the terminal transition into ``G.sink`` and, for empty
    sequences, the direct src -> sink transition.
    M-step: every row of ``prob`` is replaced by the normalized expected
    counts. Rows that received no counts are reset to uniform.

    No scaling is applied to the forward/backward values, so a sequence
    whose likelihood evaluates to exactly 0 (impossible under ``prob``, or
    underflowed) contributes nothing.

    Args:
        prob: dict ``prob[s][t]`` of transition probabilities; updated in
            place.
        sequences: iterable of token sequences; falsy entries are treated
            as the empty word.
        node_to_idx: unused here; kept for interface compatibility.
        n_states: unused here; kept for interface compatibility.
        all_nodes: list of graph nodes considered as states.
        G: graph object exposing ``src``, ``sink`` and ``label(node)``.

    Returns:
        The same ``prob`` dict, after the update.
    """
    src = G.src
    sink = G.sink

    # Which states can emit each observation (keyed by base symbol, i.e.
    # the label without its trailing "_<suffix>"). Nodes with a falsy
    # label emit nothing. Built once: it does not depend on the sequence.
    emit = {}
    for n in all_nodes:
        lab = G.label(n)
        if lab:
            base = lab.rsplit('_', 1)[0] if '_' in lab else lab
            emit.setdefault(base, []).append(n)

    total_num = {}    # (i, j) -> expected number of i -> j transitions
    total_denom = {}  # i -> expected number of transitions leaving i

    for seq in sequences:
        obs = seq or ()
        T = len(obs)

        # Forward pass: alpha[t][i] = P(obs[:t], state after t symbols = i).
        alpha = [{} for _ in range(T + 1)]
        alpha[0][src] = 1.0
        for t in range(T):
            prev = alpha[t]
            for j in emit.get(obs[t], []):
                total = 0.0
                for i, a in prev.items():
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans > 0:
                        total += a * p_trans
                if total > 0:
                    alpha[t + 1][j] = total

        # P(O | λ): the word must end with a transition into the sink.
        po = 0.0
        for i, a in alpha[T].items():
            po += a * prob.get(i, {}).get(sink, 0.0)
        if po == 0:
            continue

        # Backward pass: beta[t][i] = P(obs[t:], then sink | state = i).
        beta = [{} for _ in range(T + 1)]
        for i in all_nodes:
            p_end = prob.get(i, {}).get(sink, 0.0)
            if p_end > 0:
                beta[T][i] = p_end
        for t in range(T - 1, -1, -1):
            possible = emit.get(obs[t], [])
            nxt = beta[t + 1]
            for i in alpha[t]:
                total = 0.0
                for j in possible:
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans > 0 and j in nxt:
                        total += p_trans * nxt[j]
                if total > 0:
                    beta[t][i] = total

        # Accumulate ξ for transitions between emitting positions.
        for t in range(T):
            possible = emit.get(obs[t], [])
            nxt = beta[t + 1]
            for i, a in alpha[t].items():
                if i not in beta[t] or beta[t][i] == 0:
                    continue
                for j in possible:
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans == 0 or j not in nxt or nxt[j] == 0:
                        continue
                    xi = a * p_trans * nxt[j] / po
                    if xi > 1e-15:
                        key = (i, j)
                        total_num[key] = total_num.get(key, 0.0) + xi
                        total_denom[i] = total_denom.get(i, 0.0) + xi

        # Accumulate ξ for the terminal transition into the sink. Without
        # it the M-step would drive every prob[s][sink] to zero as soon as
        # s has any other observed successor.
        for i, a in alpha[T].items():
            p_end = prob.get(i, {}).get(sink, 0.0)
            if p_end == 0:
                continue
            xi = a * p_end / po
            if xi > 1e-15:
                key = (i, sink)
                total_num[key] = total_num.get(key, 0.0) + xi
                total_denom[i] = total_denom.get(i, 0.0) + xi

    # M-step: update probabilities.
    for s, row in prob.items():
        d = total_denom.get(s, 0.0)
        for t in row:
            key = (s, t)
            if d > 1e-15 and key in total_num:
                row[t] = total_num[key] / d
            else:
                row[t] = 0.0

    # Renormalize; rows with no mass fall back to a uniform distribution.
    for row in prob.values():
        row_sum = sum(row.values())
        if row_sum > 1e-10:
            for t in row:
                row[t] /= row_sum
        else:
            n_succ = len(row)
            for t in row:
                row[t] = 1.0 / n_succ

    return prob
|
||
|
||
|
||
def baum_welch(G, prob, sequences, iterations=10):
    """Train transition probabilities with repeated Baum-Welch passes.

    Args:
        G: k-OA graph; its ``_succ`` keys define the state list.
        prob: dict[s][t] = transition probabilities (starting point).
        sequences: list of token lists (bag, not set).
        iterations: number of EM passes to run; no convergence test is
            performed, exactly this many passes are executed.

    Returns:
        The prob dict returned by the last pass (``prob`` itself when
        ``iterations`` is not positive).
    """
    states = [*G._succ.keys()]

    # Position of every state in ``states``; handed through to each pass.
    index_of = {}
    for position, state in enumerate(states):
        index_of[state] = position
    state_count = len(states)

    current = prob
    for _pass in range(iterations):
        current = bw_iteration(
            current, sequences, index_of, state_count, states, G
        )
    return current
|
||
|
||
|
||
def baum_welch_fixed(G, prob, sequences, iterations=2):
    """Run Baum-Welch for a small, fixed number of passes (for Disambiguate).

    Recommended pass count: ℓ = 2 for |Σ| ≤ 7, ℓ = 3 for |Σ| > 7. The
    value is not derived here; the caller supplies it via ``iterations``
    (default 2).

    Returns:
        Whatever ``baum_welch`` returns for the same arguments.
    """
    trained = baum_welch(G, prob, sequences, iterations)
    return trained
|