grammar-inference-engine/bex/mdl.py
tobjend 0e2aec582b Grammar inference engine: CRX + iDRegEx ensemble with MDL scoring, MCP server, showcase, and blog post
- Ensemble inference (infer_ensemble) runs both CRX and iDRegEx, picks best by MDL
- CRX: CRX algorithm for wide coverage (accepts all sequences, large vocabulary)
- iDRegEx: iDRegEx for minimal core grammar (tightest common pattern)
- MDL scoring: fixed model_cost to count alphabet symbol occurrences, fixed dispatch order in _count_words_fast
- Fixed _match_tokens: rewritten as _match_possible with proper backtracking
- Fixed _parse_parts disjunction: children use _parse_flat_symbol to avoid dot-splitting
- MCP server: infer_best_grammar and infer_grammar tools
- Added prefer parameter (crx/idregex) to skip ensemble
- 28 passing tests
- SHOWCASE.md with Geerlingguy Galaxy demonstration
- blog_post.md with full technical deep-dive
2026-07-01 09:51:41 +02:00

198 lines
5.6 KiB
Python

"""MDL scoring for iDRegEx (Algorithm 4, arXiv 1004.2372)."""
import math
import functools
from .expr import alphabet
def model_cost(expr):
    """|r| — number of alphabet symbol occurrences in expression.

    Every symbol reported by ``alphabet(expr)`` is located in ``expr`` with a
    single left-to-right scan that tries longer symbols first. As a result, a
    symbol that is a prefix or suffix of another (``a`` vs ``a1``, ``foo`` vs
    ``foo_bar``) is not also counted inside the longer one. A match must not
    touch a word character (letter, digit or underscore) on either side.

    Returns 0 when the expression contains no (non-empty) alphabet symbols.
    """
    import re
    syms = [s for s in alphabet(expr) if s]
    if not syms:
        # An empty alternation would match the empty string at every position.
        return 0
    # Regex alternation takes the first branch that matches, so ordering the
    # symbols longest-first gives maximal-munch tokenisation.
    alternation = '|'.join(
        re.escape(s) for s in sorted(syms, key=len, reverse=True)
    )
    pattern = re.compile(rf'(?<!\w)(?:{alternation})(?!\w)')
    return len(pattern.findall(expr))
def lang_size(expr, n=None):
    """Estimate |L(r)≤n| — number of words of length ≤ n in L(r).

    Sums ``_count_words_fast(expr, k)`` for every k in 0..n, so it inherits
    that function's structural approximation.

    Args:
        expr: expression string. ``''`` is treated as the empty language (0)
            and ``'ε'`` as the language holding only the empty word (1).
        n: maximum word length. ``None`` (the default) uses ``2 * |r| + 1``,
            where ``|r|`` is ``model_cost(expr)``. An explicit ``0`` is
            honoured and counts only words of length 0.

    Returns:
        The (possibly approximate) number of words of length ≤ n.
    """
    if not expr:
        return 0
    if expr == 'ε':
        return 1
    # `is None` rather than `or`: n=0 is a legitimate bound, not "unset".
    if n is None:
        n = 2 * model_cost(expr) + 1
    return sum(_count_words_fast(expr, length) for length in range(n + 1))
@functools.lru_cache(maxsize=None)
def _count_words_fast(expr, length):
if length < 0:
return 0
if not expr or expr == '':
return 0
if expr == 'ε':
return 1 if length == 0 else 0
alpha = alphabet(expr)
if expr in alpha:
return 1 if length == 1 else 0
# 0. Concatenation: a.b.c — check FIRST so trailing quantifiers
# apply to each part individually, not the whole expression.
if '.' in expr:
parts = _split_disj_crx(expr, '.')
if len(parts) > 1:
return _count_concat(tuple(parts), length, 0)
# 1. Trailing quantifiers
if expr.endswith('+?'):
return _count_star(expr[:-2], length, min_count=0)
if expr.endswith('*'):
return _count_star(expr[:-1], length, min_count=0)
if expr.endswith('?') and not expr.endswith('+?'):
inner = expr[:-1]
return _count_words_fast(inner, length) + (1 if length == 0 else 0)
if expr.endswith('+') and not expr.endswith('+?'):
inner = expr[:-1]
return _count_star(inner, length, min_count=1)
# 2. Disjunction group: (a+b+c) for CRX or (a|b|c) for iDRegEx
if expr.startswith('(') and expr.endswith(')'):
inner = expr[1:-1]
parts = _split_disj_crx(inner, '+')
if len(parts) > 1:
return sum(_count_words_fast(p.strip(), length) for p in parts)
parts = _split_disj_crx(inner, '|')
if len(parts) > 1:
return sum(_count_words_fast(p.strip(), length) for p in parts)
return _count_words_fast(inner, length)
return 0
def _split_disj_crx(s, sep):
"""Split on `sep` at top depth (not inside nested parens)."""
depth = 0
parts = []
cur = []
for ch in s:
if ch == '(':
depth += 1
cur.append(ch)
elif ch == ')':
depth -= 1
cur.append(ch)
elif ch == sep and depth == 0:
parts.append(''.join(cur))
cur = []
else:
cur.append(ch)
parts.append(''.join(cur))
return parts
@functools.lru_cache(maxsize=None)
def _count_concat(parts_tuple, length, idx):
parts = list(parts_tuple)
if idx >= len(parts):
return 1 if length == 0 else 0
total = 0
for take in range(length + 1):
cnt = _count_words_fast(parts[idx], take)
if cnt:
total += cnt * _count_concat(parts_tuple, length - take, idx + 1)
return total
@functools.lru_cache(maxsize=None)
def _count_star(inner, length, min_count):
total = 0
for rep in range(min_count, length + 1):
total += _count_repeat(inner, rep, length)
return total
@functools.lru_cache(maxsize=None)
def _count_repeat(inner, rep, length):
if rep == 0:
return 1 if length == 0 else 0
total = 0
for take in range(length + 1):
cnt = _count_words_fast(inner, take)
if cnt:
total += cnt * _count_repeat(inner, rep - 1, length - take)
return total
def _split_disj(s):
    """Split ``s`` on top-level ``|`` (iDRegEx disjunction), ignoring nested groups.

    This was a character-for-character copy of ``_split_disj_crx`` with the
    separator fixed to ``'|'``. It now delegates to that function so the two
    cannot drift apart. The name is kept so any existing importers keep
    working.
    """
    return _split_disj_crx(s, '|')
def data_cost(expr, sequences, max_exact=50):
    """MDL data cost: Σ_i log₂(|L_i(r)|) where |L_i(r)| is the number
    of words of length len(seq_i) accepted by the grammar.

    Lower cost = more specific grammar that still covers the data.

    Exact counting is done only for sequences of at most ``max_exact``
    symbols (default 50), to prevent combinatorial explosion. Two cases are
    charged ``len(seq) * log₂(|Σ|)`` bits instead, the alphabet-size upper
    bound (0 bits when |Σ| <= 1), where Σ = ``alphabet(expr)``:

    * a longer sequence;
    * a sequence whose length has no word in L(r).

    ``sequences`` may be any iterable, including a one-shot generator; it is
    traversed exactly once.
    """
    bits_per_symbol = math.log2(max(len(alphabet(expr)), 1))
    total_cost = 0.0
    for seq in sequences:
        length = len(seq)
        if length <= max_exact:
            # _count_words_fast is memoised, so repeated lengths are cheap and
            # only the lengths that actually occur are ever computed.
            words = _count_words_fast(expr, length)
            if words > 0:
                total_cost += math.log2(words)
                continue
        total_cost += length * bits_per_symbol
    return total_cost
def mdl_score(expr, sequences):
    """Two-part MDL score of ``expr`` against ``sequences``.

    Returns ``model_cost(expr)`` (symbol occurrences in the expression) plus
    ``data_cost(expr, sequences)`` (bits needed to encode the sequences given
    the expression). The model cost is evaluated first.
    """
    return model_cost(expr) + data_cost(expr, sequences)
# Object-style wrapper kept for backward compatibility with callers that
# expect a scorer instance exposing .score().
class MDLScorer:
    """Stateless adapter exposing :func:`mdl_score` as a method."""

    def score(self, expr, sequences):
        """Return ``mdl_score(expr, sequences)``; ``self`` is not consulted."""
        result = mdl_score(expr, sequences)
        return result