"""Baum-Welch for POMM on k-OA — standard forward-backward (Rabiner 1989).""" import random import math def init_probabilities(G, sequences): """Initialize α per iKoa init (Algorithm 1, line 1). — α(src, sink) = fraction of empty words in S — α(src, s) = fraction of words starting with lab(s), split equally among all k copies of that symbol — α(s, t) for s ≠ src: chosen randomly, normalized to sum to 1 """ total = len(sequences) if total == 0: total = 1 empty_count = sum(1 for s in sequences if not s) start_counts = {} for seq in sequences: if seq: start_counts[seq[0]] = start_counts.get(seq[0], 0) + 1 prob = {} for s in G._succ: if s == G.sink: continue succ = list(G._succ[s]) if not succ: prob[s] = {} continue vals = [] for t in succ: if s == G.src: if t == G.sink: v = empty_count / total else: lab = G.label(t) base = lab.rsplit('_', 1)[0] if '_' in lab else lab count = start_counts.get(base, 0) copies = sum(1 for u in succ if G.label(u) == lab) v = (count / total) / max(copies, 1) vals.append(v) else: vals.append(random.random()) s_total = sum(vals) if s_total == 0: vals = [1.0 / len(vals)] * len(vals) else: vals = [v / s_total for v in vals] prob[s] = {t: v for t, v in zip(succ, vals)} for s in prob: for t in prob[s]: if prob[s][t] < 1e-10: prob[s][t] = 0.0 return prob def bw_iteration(prob, sequences, node_to_idx, n_states, all_nodes, G): """Single Baum-Welch iteration over all sequences.""" total_num = {} total_denom = {} for seq in sequences: if not seq: continue T = len(seq) obs = seq # which states can emit each observation? (keyed by base symbol) emit = {} for n in all_nodes: lab = G.label(n) if lab: base = lab.rsplit('_', 1)[0] if '_' in lab else lab emit.setdefault(base, []).append(n) # sink emits nothing sink = G.sink # Forward pass alpha = [{} for _ in range(T + 1)] alpha[0][G.src] = 1.0 for t in range(T): sym = obs[t] possible = emit.get(sym, []) for j in possible: total = 0.0 for i in alpha[t]: p_trans = prob.get(i, {}).get(j, 0.0) if p_trans > 0: total += alpha[t][i] * p_trans if total > 0: alpha[t + 1][j] = total # P(O | λ) po = 0.0 for i in alpha[T]: po += alpha[T][i] * prob.get(i, {}).get(sink, 0.0) if po == 0: continue # Backward pass beta = [{} for _ in range(T + 1)] for i in all_nodes: if prob.get(i, {}).get(sink, 0.0) > 0: beta[T][i] = prob[i][sink] for t in range(T - 1, -1, -1): sym = obs[t] if t < T else None possible = emit.get(sym, []) if sym else [] for i in alpha[t]: total = 0.0 for j in possible: p_trans = prob.get(i, {}).get(j, 0.0) if p_trans > 0 and j in beta[t + 1]: total += p_trans * beta[t + 1][j] if total > 0: beta[t][i] = total # Accumulate ξ and γ for t in range(T): sym_nxt = obs[t] possible = emit.get(sym_nxt, []) for i in alpha[t]: if i not in beta[t] or beta[t][i] == 0: continue for j in possible: p_trans = prob.get(i, {}).get(j, 0.0) if p_trans == 0 or j not in beta[t + 1] or beta[t + 1][j] == 0: continue xi = alpha[t][i] * p_trans * beta[t + 1][j] / po if xi > 1e-15: key = (i, j) total_num[key] = total_num.get(key, 0.0) + xi total_denom[i] = total_denom.get(i, 0.0) + xi # M-step: update probabilities for s in prob: for t in prob[s]: key = (s, t) d = total_denom.get(s, 0.0) if d > 1e-15 and key in total_num: prob[s][t] = total_num[key] / d else: prob[s][t] = 0.0 # Renormalize for s in prob: row_sum = sum(prob[s].values()) if row_sum > 1e-10: for t in prob[s]: prob[s][t] /= row_sum else: n_succ = len(prob[s]) for t in prob[s]: prob[s][t] = 1.0 / n_succ return prob def baum_welch(G, prob, sequences, iterations=10): """Baum-Welch EM training. Args: G: k-OA graph prob: dict[s][t] = transition probabilities sequences: list of token lists (bag, not set) iterations: number of EM iterations (full convergence) Returns: Updated prob dict """ all_nodes = list(G._succ.keys()) node_to_idx = {n: i for i, n in enumerate(all_nodes)} n_states = len(all_nodes) for _ in range(iterations): prob = bw_iteration(prob, sequences, node_to_idx, n_states, all_nodes, G) return prob def baum_welch_fixed(G, prob, sequences, iterations=2): """Baum-Welch with fixed small iteration count (for Disambiguate). ℓ = 2 for |Σ| ≤ 7, ℓ = 3 for |Σ| > 7. """ return baum_welch(G, prob, sequences, iterations)