"""MDL scoring for iDRegEx (Algorithm 4, arXiv 1004.2372).""" import math import functools from .expr import alphabet def model_cost(expr): """|r| — number of alphabet symbol occurrences in expression.""" import re syms = alphabet(expr) # Count each symbol by how many times it appears as a standalone word count = 0 for s in syms: # Count occurrences where symbol is bordered by operators or edges count += len(re.findall(rf'(? 1: return _count_concat(tuple(parts), length, 0) # 1. Trailing quantifiers if expr.endswith('+?'): return _count_star(expr[:-2], length, min_count=0) if expr.endswith('*'): return _count_star(expr[:-1], length, min_count=0) if expr.endswith('?') and not expr.endswith('+?'): inner = expr[:-1] return _count_words_fast(inner, length) + (1 if length == 0 else 0) if expr.endswith('+') and not expr.endswith('+?'): inner = expr[:-1] return _count_star(inner, length, min_count=1) # 2. Disjunction group: (a+b+c) for CRX or (a|b|c) for iDRegEx if expr.startswith('(') and expr.endswith(')'): inner = expr[1:-1] parts = _split_disj_crx(inner, '+') if len(parts) > 1: return sum(_count_words_fast(p.strip(), length) for p in parts) parts = _split_disj_crx(inner, '|') if len(parts) > 1: return sum(_count_words_fast(p.strip(), length) for p in parts) return _count_words_fast(inner, length) return 0 def _split_disj_crx(s, sep): """Split on `sep` at top depth (not inside nested parens).""" depth = 0 parts = [] cur = [] for ch in s: if ch == '(': depth += 1 cur.append(ch) elif ch == ')': depth -= 1 cur.append(ch) elif ch == sep and depth == 0: parts.append(''.join(cur)) cur = [] else: cur.append(ch) parts.append(''.join(cur)) return parts @functools.lru_cache(maxsize=None) def _count_concat(parts_tuple, length, idx): parts = list(parts_tuple) if idx >= len(parts): return 1 if length == 0 else 0 total = 0 for take in range(length + 1): cnt = _count_words_fast(parts[idx], take) if cnt: total += cnt * _count_concat(parts_tuple, length - take, idx + 1) return total 
@functools.lru_cache(maxsize=None)
def _count_star(inner, length, min_count):
    """Count derivations of total length ``length`` from repeating ``inner``.

    Sums ``_count_repeat(inner, rep, length)`` over the repetition counts
    ``rep`` starting at ``min_count``.

    Args:
        inner: Sub-expression string being repeated.
        length: Target total word length.
        min_count: Smallest repetition count allowed (the visible callers
            pass 0 for ``*`` / ``+?`` and 1 for ``+``).

    Returns:
        The summed count as an int.
    """
    # Repetition counts up to ``length`` are always tried.  The upper bound
    # is widened to ``min_count`` so the mandatory repetitions are still
    # tried when ``length < min_count``: ``(a?)+`` derives the empty word
    # through one empty repetition, which ``range(1, 0 + 1)`` would skip.
    max_rep = max(length, min_count)
    return sum(
        _count_repeat(inner, rep, length)
        for rep in range(min_count, max_rep + 1)
    )


@functools.lru_cache(maxsize=None)
def _count_repeat(inner, rep, length):
    """Count ways to build total length ``length`` from ``rep`` copies of ``inner``.

    Args:
        inner: Sub-expression string being repeated.
        rep: Exact number of repetitions.
        length: Target total word length.

    Returns:
        The number of ways to split ``length`` over ``rep`` consecutive
        pieces, weighting each piece by ``_count_words_fast(inner, piece)``.
    """
    if rep == 0:
        # Zero repetitions derive exactly the empty word.
        return 1 if length == 0 else 0

    total = 0
    for take in range(length + 1):
        # ``take`` symbols go to the first copy, the rest to the others.
        first = _count_words_fast(inner, take)
        if first:
            total += first * _count_repeat(inner, rep - 1, length - take)
    return total


def _split_disj(s):
    """Split ``s`` on ``|`` characters that are not nested inside parentheses.

    Args:
        s: Expression text.

    Returns:
        List of the pieces between top-level ``|`` separators; always holds
        at least one (possibly empty) string.
    """
    depth = 0
    parts = []
    current = []
    for ch in s:
        if ch == '|' and depth == 0:
            # Top-level separator: close the current alternative.
            parts.append(''.join(current))
            current = []
            continue
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        current.append(ch)
    parts.append(''.join(current))
    return parts


def data_cost(expr, sequences, max_exact=50):
    """MDL data cost: Σ_i log₂(|L_i(r)|).

    |L_i(r)| is the size reported by ``_count_words_fast(expr, len(seq_i))``.
    Lower cost = more specific grammar that still covers the data.

    Args:
        expr: Expression string being scored.
        sequences: Iterable of sized sequences.  It is traversed exactly
            once, so generators and other one-shot iterables are supported.
        max_exact: Longest sequence length for which the exact count is
            used (default 50, to prevent combinatorial explosion).  Longer
            sequences use the alphabet-size upper bound.

    Returns:
        Total cost in bits as a float; ``0.0`` when ``sequences`` is empty.
    """
    # Materialise the lengths once: the input may be a one-shot iterator.
    lengths = [len(seq) for seq in sequences]
    if not lengths:
        return 0.0

    # Per-symbol fallback cost; ``max(..., 1)`` keeps log2 defined when the
    # alphabet is empty.
    fallback_bits = math.log2(max(len(alphabet(expr)), 1))

    # Exact counts are needed only for the lengths that actually occur.
    exact_bits = {}
    for length in set(lengths):
        if length > max_exact:
            continue
        size = _count_words_fast(expr, length)
        if size > 0:
            exact_bits[length] = math.log2(size)

    total_cost = 0.0
    for length in lengths:
        if length in exact_bits:
            total_cost += exact_bits[length]
        else:
            # Too long for exact counting, or the count came back as zero:
            # charge the alphabet-size upper bound instead.
            total_cost += length * fallback_bits
    return total_cost


def mdl_score(expr, sequences):
    """MDL = model cost + data cost.

    Args:
        expr: Expression string being scored.
        sequences: Sequences passed on to ``data_cost``.

    Returns:
        ``model_cost(expr) + data_cost(expr, sequences)``.
    """
    model = model_cost(expr)
    data = data_cost(expr, sequences)
    return model + data


# For backward compatibility
class MDLScorer:
    """Class-based wrapper around ``mdl_score``."""

    def score(self, expr, sequences):
        """Return ``mdl_score(expr, sequences)``."""
        return mdl_score(expr, sequences)