"""MDL scoring for iDRegEx (Algorithm 4, arXiv 1004.2372).""" import math from .expr import alphabet def model_cost(expr): """|r| — number of alphabet symbol occurrences in expression.""" import re cleaned = re.sub(r'[+?*()|.]', '', expr) cleaned = re.sub(r'_\d+', '', cleaned) cleaned = re.sub(r'[ε∅]', '', cleaned) return len(cleaned) def lang_size(expr, n=None): """Estimate |L(r)≤n| — number of words of length ≤ n in L(r). Simple approximation based on expression structure. """ if not expr or expr == '∅': return 0 if expr == 'ε': return 1 n = n or (2 * model_cost(expr) + 1) total = 0 for length in range(n + 1): total += _count_words_fast(expr, length) return total def _count_words_fast(expr, length): if length < 0: return 0 if not expr or expr == '∅': return 0 if expr == 'ε': return 1 if length == 0 else 0 alpha = alphabet(expr) if expr in alpha: return 1 if length == 1 else 0 if '+' in expr: inner = expr.rstrip('+') if inner.endswith('?'): inner = inner[:-1] return _count_star(inner, length, min_count=1) if expr.endswith('?'): inner = expr[:-1] return _count_words_fast(inner, length) + (1 if length == 0 else 0) if expr.startswith('(') and '|' in expr: parts = _split_disj(expr[1:-1]) return sum(_count_words_fast(p.strip(), length) for p in parts) if '.' in expr: parts = expr.split('.') return _count_concat(parts, length, 0) return 0 def _count_concat(parts, length, idx): if idx >= len(parts): return 1 if length == 0 else 0 total = 0 for take in range(length + 1): cnt = _count_words_fast(parts[idx], take) if cnt: total += cnt * _count_concat(parts, length - take, idx + 1) return total def _count_star(inner, length, min_count): total = 0 for rep in range(min_count, length + 1): total += _count_repeat(inner, rep, length) return total def _count_repeat(inner, rep, length): if rep == 0: return 1 if length == 0 else 0 total = 0 for take in range(length + 1): cnt = _count_words_fast(inner, take) if cnt: total += cnt * _count_repeat(inner, rep - 1, length - take) return total def _split_disj(s): depth = 0 parts = [] cur = [] for ch in s: if ch == '(': depth += 1 cur.append(ch) elif ch == ')': depth -= 1 cur.append(ch) elif ch == '|' and depth == 0: parts.append(''.join(cur)) cur = [] else: cur.append(ch) parts.append(''.join(cur)) return parts def data_cost(expr, sequences): """MDL data cost: Σ_i log₂(|L=i(r)| / |S=i|) adjusted. Simplified form: for each word in S, cost = log₂(lang_size of all words of that length). """ n = 2 * model_cost(expr) + 1 total_cost = 0.0 for seq in sequences: length = len(seq) if length <= n: lang_at_len = _count_words_fast(expr, length) if lang_at_len > 0: total_cost += math.log2(lang_at_len) if lang_at_len > 0 else 0 return total_cost def mdl_score(expr, sequences): """MDL = model cost + data cost.""" model = model_cost(expr) data = data_cost(expr, sequences) return model + data # For backward compatibility class MDLScorer: def score(self, expr, sequences): return mdl_score(expr, sequences)