193 lines
5.8 KiB
Python
193 lines
5.8 KiB
Python
|
|
"""Baum-Welch for POMM on k-OA — standard forward-backward (Rabiner 1989)."""
|
|||
|
|
|
|||
|
|
import random
|
|||
|
|
import math
|
|||
|
|
|
|||
|
|
|
|||
|
|
def init_probabilities(G, sequences):
    """Build the initial transition probabilities α (iKoa init, Algorithm 1, line 1).

    For every state ``s`` other than the sink and every successor ``t``:

    - α(src, sink) = fraction of empty words in ``sequences``.
    - α(src, t)    = fraction of words whose first token equals the base
      symbol of ``lab(t)``, split equally among all successors of src that
      share that base symbol (the k copies of the symbol).
    - α(s, t) for s ≠ src: drawn with ``random.random()``.

    Each row is then normalized to sum to 1; a row whose raw values are all
    zero becomes uniform. Finally, entries below 1e-10 are clamped to 0.0
    (the row is not renormalized after clamping).

    The "base symbol" of a label is the label with a trailing ``_<suffix>``
    removed (``"a_2" -> "a"``); labels without an underscore are used as-is.

    Args:
        G: graph object exposing ``_succ`` (mapping state -> iterable of
            successor states), ``src``, ``sink`` and ``label(state)``.
        sequences: sized collection of words; each word is a sequence of
            tokens that are compared against base symbols.

    Returns:
        dict: ``prob[s][t]`` = initial probability of the transition s -> t.
        The sink has no row; a state without successors maps to ``{}``.
    """
    def base_symbol(label):
        return label.rsplit('_', 1)[0] if '_' in label else label

    # Avoid dividing by zero when there are no training words at all.
    total = len(sequences) or 1
    empty_count = sum(1 for seq in sequences if not seq)

    # How many words start with each token.
    start_counts = {}
    for seq in sequences:
        if seq:
            start_counts[seq[0]] = start_counts.get(seq[0], 0) + 1

    prob = {}
    for s in G._succ:
        if s == G.sink:
            continue
        succ = list(G._succ[s])
        if not succ:
            prob[s] = {}
            continue

        if s == G.src:
            # Base symbol of every non-sink successor, computed once per row.
            succ_bases = [
                None if t == G.sink else base_symbol(G.label(t))
                for t in succ
            ]
            # Copies are counted per BASE symbol so that "a_1" and "a_2"
            # share the start mass of "a" instead of each receiving all of it.
            copies = {}
            for base in succ_bases:
                if base is not None:
                    copies[base] = copies.get(base, 0) + 1

            vals = []
            for t, base in zip(succ, succ_bases):
                if t == G.sink:
                    vals.append(empty_count / total)
                else:
                    share = start_counts.get(base, 0) / total
                    vals.append(share / copies[base])
        else:
            # One random draw per successor, in successor order.
            vals = [random.random() for _ in succ]

        s_total = sum(vals)
        if s_total == 0:
            vals = [1.0 / len(vals)] * len(vals)
        else:
            vals = [v / s_total for v in vals]
        prob[s] = dict(zip(succ, vals))

    # Clamp numerically negligible probabilities to exactly zero.
    for row in prob.values():
        for t in row:
            if row[t] < 1e-10:
                row[t] = 0.0

    return prob
|
|||
|
|
|
|||
|
|
|
|||
|
|
def bw_iteration(prob, sequences, node_to_idx, n_states, all_nodes, G):
    """Run one Baum-Welch (EM) iteration over all sequences.

    E-step: for every word, a forward pass (``alpha``) and a backward pass
    (``beta``) over the states able to emit each token give the posterior
    expected count of every transition, including the terminating
    transition into the sink. Empty words contribute one ``src -> sink``
    transition. Words with zero likelihood under ``prob`` are skipped.

    M-step: each ``prob[s][t]`` becomes (expected count of s -> t) /
    (expected count of transitions leaving s); transitions with no count
    become 0. Rows that end up (near) all-zero are reset to uniform.

    ``alpha``/``beta`` hold unscaled probabilities (no per-step scaling).

    Args:
        prob: dict ``prob[s][t]`` of transition probabilities; modified in
            place.
        sequences: iterable of words (sequences of tokens). Tokens are
            matched against the base symbol of state labels, i.e. the
            label with a trailing ``_<suffix>`` removed.
        node_to_idx: unused; kept for interface compatibility.
        n_states: unused; kept for interface compatibility.
        all_nodes: iterable of all states of the graph.
        G: graph object exposing ``src``, ``sink`` and ``label(state)``.

    Returns:
        The same ``prob`` dict, after the update.
    """
    src = G.src
    sink = G.sink

    # Which states can emit each token? Keyed by base symbol. This depends
    # only on the graph, so it is built once instead of once per word.
    # States with an empty/None label (e.g. the sink) emit nothing.
    emit = {}
    for n in all_nodes:
        lab = G.label(n)
        if lab:
            base = lab.rsplit('_', 1)[0] if '_' in lab else lab
            emit.setdefault(base, []).append(n)

    total_num = {}    # expected count of each transition (i, j)
    total_denom = {}  # expected count of all transitions leaving i

    def add_count(i, j, xi):
        # Ignore numerically negligible contributions.
        if xi > 1e-15:
            key = (i, j)
            total_num[key] = total_num.get(key, 0.0) + xi
            total_denom[i] = total_denom.get(i, 0.0) + xi

    for obs in sequences:
        if not obs:
            # Empty word: its only path is src -> sink, with posterior 1,
            # provided the current model allows that transition at all.
            if prob.get(src, {}).get(sink, 0.0) > 0:
                add_count(src, sink, 1.0)
            continue

        T = len(obs)
        # Candidate emitting states for each position of the word.
        candidates = [emit.get(sym, []) for sym in obs]

        # Forward pass: alpha[t][i] = P(first t tokens, state i after them).
        alpha = [{} for _ in range(T + 1)]
        alpha[0][src] = 1.0
        for t in range(T):
            for j in candidates[t]:
                total = 0.0
                for i, a in alpha[t].items():
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans > 0:
                        total += a * p_trans
                if total > 0:
                    alpha[t + 1][j] = total

        # P(O | λ): the word must end with a transition into the sink.
        po = 0.0
        for i, a in alpha[T].items():
            po += a * prob.get(i, {}).get(sink, 0.0)
        if po == 0:
            continue

        # Backward pass: beta[t][i] = P(remaining tokens, then sink | state i).
        beta = [{} for _ in range(T + 1)]
        for i in all_nodes:
            p_end = prob.get(i, {}).get(sink, 0.0)
            if p_end > 0:
                beta[T][i] = p_end
        for t in range(T - 1, -1, -1):
            for i in alpha[t]:
                total = 0.0
                for j in candidates[t]:
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    if p_trans > 0 and j in beta[t + 1]:
                        total += p_trans * beta[t + 1][j]
                if total > 0:
                    beta[t][i] = total

        # Accumulate ξ for the emitting transitions i -> j at every step.
        for t in range(T):
            for i, a in alpha[t].items():
                if beta[t].get(i, 0.0) == 0:
                    continue
                for j in candidates[t]:
                    p_trans = prob.get(i, {}).get(j, 0.0)
                    b_next = beta[t + 1].get(j, 0.0)
                    if p_trans == 0 or b_next == 0:
                        continue
                    add_count(i, j, a * p_trans * b_next / po)

        # Accumulate ξ for the terminating transition i -> sink. Without
        # this the M-step would zero every sink transition of a state that
        # also has other counted transitions, and the affected words would
        # get zero likelihood on the next iteration.
        for i, a in alpha[T].items():
            add_count(i, sink, a * beta[T].get(i, 0.0) / po)

    # M-step: update probabilities
    for s, row in prob.items():
        d = total_denom.get(s, 0.0)
        for t in row:
            key = (s, t)
            if d > 1e-15 and key in total_num:
                row[t] = total_num[key] / d
            else:
                row[t] = 0.0

    # Renormalize; rows without any expected count fall back to uniform.
    for row in prob.values():
        row_sum = sum(row.values())
        if row_sum > 1e-10:
            for t in row:
                row[t] /= row_sum
        else:
            n_succ = len(row)
            for t in row:
                row[t] = 1.0 / n_succ

    return prob
|
|||
|
|
|
|||
|
|
|
|||
|
|
def baum_welch(G, prob, sequences, iterations=10):
    """Train transition probabilities with repeated Baum-Welch EM steps.

    The state list is taken from the keys of ``G._succ`` and handed,
    together with an index map and the state count, to ``bw_iteration``
    exactly ``iterations`` times. No convergence test is performed; the
    loop always runs the requested number of steps.

    Args:
        G: k-OA graph (must expose ``_succ``).
        prob: dict[s][t] = transition probabilities
        sequences: list of token lists (bag, not set)
        iterations: number of EM iterations to run

    Returns:
        The prob dict returned by the last iteration (``prob`` itself when
        ``iterations`` is 0).
    """
    states = list(G._succ.keys())
    index_of = dict(zip(states, range(len(states))))
    state_count = len(states)

    current = prob
    for _ in range(iterations):
        current = bw_iteration(
            current, sequences, index_of, state_count, states, G
        )
    return current
|
|||
|
|
|
|||
|
|
|
|||
|
|
def baum_welch_fixed(G, prob, sequences, iterations=2):
    """Run Baum-Welch for a small, fixed number of iterations (for Disambiguate).

    Thin wrapper that forwards all arguments to ``baum_welch``; only the
    default iteration count differs (2 here).

    Recommended iteration count ℓ: ℓ = 2 for |Σ| ≤ 7, ℓ = 3 for |Σ| > 7.
    This function does not inspect the alphabet size; the caller chooses
    ℓ and passes it as ``iterations``.
    """
    trained = baum_welch(G, prob, sequences, iterations=iterations)
    return trained
|