grammar-inference-engine/bex/baum_welch.py
tobjend 7c00c6713d Initial commit: BEX-based grammar inference engine
- CRX: direct CHARE inference (Algorithm 7, TODS 2010)
- iDRegEx: k-ORE inference (Algorithm 4, arXiv 2010)
- RWR₀: SORE repair (Algorithm 6, TODS 2010)
- rwr²: k-ORE extraction (Algorithm 3, arXiv 2010)
- SOA, k-OA, iKoa, 2T-INF, Baum-Welch
- Ansible role grammar adapter
- Generic YAML key-path converter
- 28 tests, all passing
2026-07-01 08:01:16 +02:00

192 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Baum-Welch for POMM on k-OA — standard forward-backward (Rabiner 1989)."""
import random
import math
def init_probabilities(G, sequences):
    """Build the initial transition probabilities (iKoa init, Algorithm 1, line 1).

    For the source state:
      * α(src, sink) = fraction of empty words in ``sequences``;
      * α(src, s)    = fraction of words starting with the symbol of ``s``,
        split equally among all copies of that symbol reachable from src.
    For every other state the outgoing weights are drawn with
    ``random.random()``.  Every row is then normalised to sum to 1 (or made
    uniform if all its raw weights are zero), and entries below 1e-10 are
    clamped to exactly 0.0.

    Args:
        G: graph object; this function reads ``G._succ`` (mapping state ->
            iterable of successors), ``G.src``, ``G.sink`` and ``G.label(n)``.
        sequences: sized collection of token sequences.

    Returns:
        dict mapping each non-sink state to ``{successor: probability}``.
    """

    def base_symbol(label):
        # Same convention as bw_iteration: strip one trailing "_<suffix>".
        return label.rsplit('_', 1)[0] if '_' in label else label

    # Avoid dividing by zero when there are no sequences at all.
    total = len(sequences) or 1
    empty_count = sum(1 for seq in sequences if not seq)

    start_counts = {}
    for seq in sequences:
        if seq:
            start_counts[seq[0]] = start_counts.get(seq[0], 0) + 1

    prob = {}
    for s in G._succ:
        if s == G.sink:
            continue
        succ = list(G._succ[s])
        if not succ:
            prob[s] = {}
            continue

        if s == G.src:
            # Count copies per *base symbol*.  Comparing full labels would
            # treat "a_1" and "a_2" as different symbols, so the start mass
            # would never actually be split between the copies.
            bases = {t: base_symbol(G.label(t)) for t in succ if t != G.sink}
            copies = {}
            for t in succ:
                if t != G.sink:
                    copies[bases[t]] = copies.get(bases[t], 0) + 1
            vals = []
            for t in succ:
                if t == G.sink:
                    vals.append(empty_count / total)
                else:
                    count = start_counts.get(bases[t], 0)
                    vals.append((count / total) / max(copies[bases[t]], 1))
        else:
            vals = [random.random() for _ in succ]

        s_total = sum(vals)
        if s_total == 0:
            vals = [1.0 / len(vals)] * len(vals)
        else:
            vals = [v / s_total for v in vals]
        prob[s] = dict(zip(succ, vals))

    # Clamp numerically negligible entries to an exact zero.
    for row in prob.values():
        for t in row:
            if row[t] < 1e-10:
                row[t] = 0.0
    return prob
def bw_iteration(prob, sequences, node_to_idx, n_states, all_nodes, G):
    """Run a single Baum-Welch (EM) iteration over all sequences.

    E-step: for every sequence a forward/backward pass computes the expected
    number of times each transition is used, *including* the final transition
    into ``G.sink``.  Empty sequences are handled by the same code path and
    contribute one expected ``src -> sink`` transition.  Sequences that have
    zero likelihood under the current ``prob`` are ignored.

    M-step: each ``prob[s][t]`` is replaced by its expected count divided by
    the expected number of transitions leaving ``s``; rows are then
    renormalised, and rows that received no mass become uniform.

    Args:
        prob: ``dict[s][t]`` of transition probabilities; updated in place.
        sequences: iterable of token sequences.
        node_to_idx: unused; kept for interface compatibility.
        n_states: unused; kept for interface compatibility.
        all_nodes: iterable of all states of the automaton.
        G: graph object; reads ``G.src``, ``G.sink`` and ``G.label(n)``.

    Returns:
        The same ``prob`` dict, after the update.
    """
    src = G.src
    sink = G.sink

    # Which states can emit each observation?  Keyed by base symbol (label
    # with one trailing "_<suffix>" removed).  States with a falsy label emit
    # nothing.  This depends only on the graph, so build it once.
    emit = {}
    for n in all_nodes:
        lab = G.label(n)
        if lab:
            base = lab.rsplit('_', 1)[0] if '_' in lab else lab
            emit.setdefault(base, []).append(n)

    total_num = {}
    total_denom = {}

    def accumulate(i, j, xi):
        # Ignore numerically negligible contributions.
        if xi > 1e-15:
            total_num[(i, j)] = total_num.get((i, j), 0.0) + xi
            total_denom[i] = total_denom.get(i, 0.0) + xi

    for seq in sequences:
        T = len(seq)

        # Forward pass: alpha[t][s] = P(o_1..o_t, state after t symbols = s).
        alpha = [{} for _ in range(T + 1)]
        alpha[0][src] = 1.0
        for t, sym in enumerate(seq):
            for j in emit.get(sym, ()):
                total = 0.0
                for i, a in alpha[t].items():
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans > 0:
                        total += a * p_trans
                if total > 0:
                    alpha[t + 1][j] = total

        # P(O | λ): the word must end with a transition into the sink.
        po = 0.0
        for i, a in alpha[T].items():
            po += a * prob.get(i, {}).get(sink, 0.0)
        if po == 0:
            continue

        # Backward pass: beta[t][s] = P(o_{t+1}..o_T, then sink | state = s).
        beta = [{} for _ in range(T + 1)]
        for i in all_nodes:
            p_end = prob.get(i, {}).get(sink, 0.0)
            if p_end > 0:
                beta[T][i] = p_end
        for t in range(T - 1, -1, -1):
            possible = emit.get(seq[t], ())
            for i in alpha[t]:
                total = 0.0
                for j in possible:
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans > 0 and j in beta[t + 1]:
                        total += p_trans * beta[t + 1][j]
                if total > 0:
                    beta[t][i] = total

        # Expected counts for transitions into emitting states.
        for t in range(T):
            possible = emit.get(seq[t], ())
            for i, a in alpha[t].items():
                if not beta[t].get(i):
                    continue
                for j in possible:
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    b = beta[t + 1].get(j)
                    if p_trans == 0 or not b:
                        continue
                    accumulate(i, j, a * p_trans * b / po)

        # Expected counts for the terminating transition into the sink.
        # Without this every prob[i][sink] would be re-estimated as 0 and all
        # sequences would have zero likelihood on the next iteration.
        for i, a in alpha[T].items():
            p_end = prob.get(i, {}).get(sink, 0.0)
            if p_end > 0:
                accumulate(i, sink, a * p_end / po)

    # M-step: update probabilities.
    for s in prob:
        d = total_denom.get(s, 0.0)
        for t in prob[s]:
            key = (s, t)
            if d > 1e-15 and key in total_num:
                prob[s][t] = total_num[key] / d
            else:
                prob[s][t] = 0.0

    # Renormalize; rows with no mass fall back to a uniform distribution.
    for s in prob:
        row_sum = sum(prob[s].values())
        if row_sum > 1e-10:
            for t in prob[s]:
                prob[s][t] /= row_sum
        else:
            n_succ = len(prob[s])
            for t in prob[s]:
                prob[s][t] = 1.0 / n_succ
    return prob
def baum_welch(G, prob, sequences, iterations=10):
    """Train transition probabilities with repeated Baum-Welch iterations.

    Args:
        G: k-OA graph; its states are taken from the keys of ``G._succ``.
        prob: dict[s][t] = transition probabilities.
        sequences: list of token lists (bag, not set).
        iterations: how many EM iterations to run.  The count is fixed; no
            convergence test is performed.

    Returns:
        The prob dict returned by the last iteration (``prob`` itself when
        ``iterations`` is 0).
    """
    states = list(G._succ.keys())
    state_count = len(states)
    index_of = {}
    for position, state in enumerate(states):
        index_of[state] = position

    current = prob
    for _ in range(iterations):
        current = bw_iteration(
            current, sequences, index_of, state_count, states, G
        )
    return current
def baum_welch_fixed(G, prob, sequences, iterations=2):
    """Run Baum-Welch for a small, fixed number of iterations (for Disambiguate).

    The intended setting is 2 iterations for |Σ| ≤ 7 and 3 for |Σ| > 7.  This
    function does not inspect the alphabet itself: the caller passes the
    desired count via ``iterations`` (default 2), and the call is forwarded
    to ``baum_welch`` unchanged.
    """
    trained = baum_welch(G, prob, sequences, iterations)
    return trained