# Changelog:
# - Ensemble inference (infer_ensemble) runs both CRX and iDRegEx and picks the best by MDL.
# - CRX: CRX algorithm for wide coverage (accepts all sequences, large vocabulary).
# - iDRegEx: iDRegEx for a minimal core grammar (tightest common pattern).
# - MDL scoring: fixed model_cost to count alphabet symbol occurrences; fixed dispatch order in _count_words_fast.
# - Fixed _match_tokens: rewritten as _match_possible with proper backtracking.
# - Fixed _parse_parts disjunction: children use _parse_flat_symbol to avoid dot-splitting.
# - MCP server: infer_best_grammar and infer_grammar tools.
# - Added prefer parameter (crx/idregex) to skip ensemble.
# - 28 passing tests.
# - SHOWCASE.md with Geerlingguy Galaxy demonstration.
# - blog_post.md with full technical deep-dive.
# File: 349 lines, 11 KiB, Python.
"""Ensemble grammar inference — run multiple algorithms, pick best by MDL scoring."""
|
|
|
|
import re
|
|
from .crx import CRX
|
|
from .idregex import idregex
|
|
from .expr import alphabet
|
|
from .mdl import model_cost, mdl_score
|
|
|
|
|
|
def _parse_parts(expr):
    """Parse a grammar expression into a list of match tokens.

    Each token is a ``(type, value, quantifier)`` tuple:

    * type: ``'symbol'`` | ``'disj'`` | ``'concat'`` | ``'empty'``
    * value: the symbol name for ``'symbol'``, a list of child tokens for
      ``'disj'`` / ``'concat'``, and ``''`` for ``'empty'``
    * quantifier: ``''`` | ``'?'`` | ``'+'`` | ``'*'`` | ``'+?'``

    Args:
        expr: Expression string. ``''`` / ``'∅'`` and ``'ε'`` are special-cased.

    Returns:
        A list of tokens. In practice it holds a single token, or is empty
        when nothing parseable remains after stripping the quantifier.
    """
    if not expr or expr == '∅':
        return [('empty', '', '')]
    if expr == 'ε':
        return [('empty', '', '+?')]

    # 1. Concatenation: split the outermost level on '.'.  This must happen
    #    BEFORE stripping a trailing quantifier, because quantifiers belong
    #    to the individual parts (e.g. a?.b+).
    concat_parts = _split_outer(expr.strip(), '.')
    if len(concat_parts) > 1:
        children = []
        for part in concat_parts:
            children.extend(_parse_parts(part.strip()))
        return [('concat', children, '')]

    # 2. Quantifier suffix on this single part.  '+?' is tested first so it
    #    is not mistaken for a bare '?'.
    quantifier = ''
    for suffix in ('+?', '*', '?', '+'):
        if expr.endswith(suffix):
            quantifier = suffix
            expr = expr[:-len(suffix)]
            break

    # 3. Parenthesised group: (a+b+c) for CRX or (a|b|c) for iDRegEx.
    if expr.startswith('(') and expr.endswith(')'):
        inner = expr[1:-1]
        # Try CRX-style ('+') first, then iDRegEx-style ('|').
        disj_parts = _split_outer(inner, '+')
        if len(disj_parts) <= 1:
            disj_parts = _split_outer(inner, '|')
        if len(disj_parts) > 1:
            # Children are parsed as flat symbols: dots inside a group are
            # part of the symbol name (e.g. "community.docker.docker_image").
            children = [_parse_flat_symbol(p.strip()) for p in disj_parts]
            return [('disj', children, quantifier)]

        # Single element inside parens: a flat symbol.  The quantifier that
        # followed the closing paren must not be lost (e.g. "(a)+").
        child = _parse_flat_symbol(inner)
        if not quantifier:
            return [child]
        if not child[2]:
            # No inner quantifier: attach the outer one directly.
            return [(child[0], child[1], quantifier)]
        # Both inner and outer quantifiers: keep both by wrapping the child
        # in a one-element group carrying the outer quantifier.
        return [('disj', [child], quantifier)]

    # 4. Plain symbol.
    if expr and expr not in ('∅', 'ε'):
        return [('symbol', expr, quantifier)]

    return []
|
|
|
|
|
|
def _parse_flat_symbol(s):
|
|
"""Parse a single symbol with optional quantifier, no dot splitting.
|
|
|
|
Unlike _parse_parts, this treats dots as part of the symbol name
|
|
(e.g. 'community.docker.docker_image' stays as one symbol).
|
|
"""
|
|
s = s.strip()
|
|
quantifier = ''
|
|
if s.endswith('+?'):
|
|
quantifier = '+?'
|
|
s = s[:-2]
|
|
elif s.endswith('*'):
|
|
quantifier = '*'
|
|
s = s[:-1]
|
|
elif s.endswith('?'):
|
|
quantifier = '?'
|
|
s = s[:-1]
|
|
elif s.endswith('+'):
|
|
quantifier = '+'
|
|
s = s[:-1]
|
|
if s and s not in ('∅', 'ε'):
|
|
return ('symbol', s, quantifier)
|
|
return ('empty', '', quantifier)
|
|
|
|
|
|
def _split_outer(s, sep):
|
|
"""Split on `sep` at the top level (not inside parentheses)."""
|
|
depth = 0
|
|
parts = []
|
|
cur = []
|
|
for ch in s:
|
|
if ch == '(':
|
|
depth += 1
|
|
cur.append(ch)
|
|
elif ch == ')':
|
|
depth -= 1
|
|
cur.append(ch)
|
|
elif ch == sep and depth == 0:
|
|
parts.append(''.join(cur))
|
|
cur = []
|
|
else:
|
|
cur.append(ch)
|
|
parts.append(''.join(cur))
|
|
return parts
|
|
|
|
|
|
def _match_possible(token, seq, pos):
|
|
"""Return all possible end positions after matching this token starting at pos."""
|
|
ttype, tval, tquant = token
|
|
positions = []
|
|
|
|
if ttype == 'empty':
|
|
positions.append(pos)
|
|
|
|
elif ttype == 'symbol':
|
|
if tquant in ('', '?'):
|
|
if pos < len(seq) and seq[pos] == tval:
|
|
positions.append(pos + 1)
|
|
if tquant == '?':
|
|
positions.append(pos)
|
|
elif tquant in ('+?', '*'):
|
|
positions.append(pos)
|
|
cnt = pos
|
|
while cnt < len(seq) and seq[cnt] == tval:
|
|
cnt += 1
|
|
positions.append(cnt)
|
|
elif tquant == '+':
|
|
if pos < len(seq) and seq[pos] == tval:
|
|
cnt = pos + 1
|
|
positions.append(cnt)
|
|
while cnt < len(seq) and seq[cnt] == tval:
|
|
cnt += 1
|
|
positions.append(cnt)
|
|
|
|
elif ttype == 'disj':
|
|
if tquant in ('', '?'):
|
|
for child in tval:
|
|
for ep in _match_possible(child, seq, pos):
|
|
positions.append(ep)
|
|
if tquant == '?':
|
|
positions.append(pos)
|
|
elif tquant in ('+?', '*'):
|
|
positions.append(pos)
|
|
for child in tval:
|
|
for ep in _match_possible(child, seq, pos):
|
|
if ep > pos:
|
|
positions.append(ep)
|
|
# After consuming one, recurse to try more
|
|
for ep2 in _match_possible(token, seq, ep):
|
|
if ep2 > ep:
|
|
positions.append(ep2)
|
|
elif tquant == '+':
|
|
for child in tval:
|
|
for ep in _match_possible(child, seq, pos):
|
|
if ep > pos:
|
|
positions.append(ep)
|
|
for ep2 in _match_possible(token, seq, ep):
|
|
if ep2 > ep:
|
|
positions.append(ep2)
|
|
|
|
elif ttype == 'concat':
|
|
# Match all children sequentially
|
|
def _match_seq(children, start):
|
|
cur = [start]
|
|
for child in children:
|
|
next_cur = []
|
|
for p in cur:
|
|
next_cur.extend(_match_possible(child, seq, p))
|
|
cur = next_cur
|
|
if not cur:
|
|
break
|
|
return cur
|
|
if tquant in ('', '?'):
|
|
positions.extend(_match_seq(tval, pos))
|
|
if tquant == '?':
|
|
positions.append(pos)
|
|
elif tquant in ('+?', '*'):
|
|
positions.append(pos)
|
|
inner_end = _match_seq(tval, pos)
|
|
for ep in inner_end:
|
|
if ep > pos:
|
|
positions.append(ep)
|
|
for ep2 in _match_possible(token, seq, ep):
|
|
if ep2 > ep:
|
|
positions.append(ep2)
|
|
elif tquant == '+':
|
|
inner_end = _match_seq(tval, pos)
|
|
for ep in inner_end:
|
|
if ep > pos:
|
|
positions.append(ep)
|
|
for ep2 in _match_possible(token, seq, ep):
|
|
if ep2 > ep:
|
|
positions.append(ep2)
|
|
|
|
return positions
|
|
|
|
|
|
def _match_tokens(tokens, seq, pos=0):
|
|
"""Try to match tokens against seq starting at pos. Returns max position or None."""
|
|
cur = [pos]
|
|
for token in tokens:
|
|
next_cur = []
|
|
for p in cur:
|
|
next_cur.extend(_match_possible(token, seq, p))
|
|
cur = next_cur
|
|
if not cur:
|
|
return None
|
|
return max(cur) if cur else pos
|
|
|
|
|
|
def _matches(grammar, sequence):
|
|
"""Check if a sequence matches the grammar."""
|
|
try:
|
|
tokens = _parse_parts(grammar.strip())
|
|
if not tokens:
|
|
return False
|
|
end = _match_tokens(tokens, sequence)
|
|
if end is None:
|
|
return False
|
|
return end == len(sequence)
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
def mdl_score_simple(grammar, sequences):
    """Score `grammar` against `sequences` by delegating to ``mdl_score``.

    Lower is better. According to the original author's notes the score
    follows Bex et al.: model_cost (number of alphabet symbol occurrences in
    the expression) plus a data cost Σ log₂(|L(r)|) that penalizes overly
    general grammars.
    NOTE(review): the formula lives in ``.mdl`` and is not visible here —
    confirm against that module.

    Args:
        grammar: Grammar expression to score.
        sequences: The sequences the grammar was inferred from.

    Returns:
        Whatever ``mdl_score(grammar, sequences)`` returns.
    """
    score = mdl_score(grammar, sequences)
    return score
|
|
|
|
|
|
def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
    """Infer a grammar with CRX and iDRegEx and keep the lowest-MDL result.

    Args:
        sequences: List of sequences, each a list of strings.
        kmax: Passed to ``idregex`` as ``kmax``.
        N: Passed to ``idregex`` as ``N``.
        prefer: Optional 'crx' or 'idregex' (case-insensitive). When given,
            only that algorithm's result is returned and no comparison is
            made. Any other value falls through to the full ensemble.

    Returns:
        dict with keys:
            best: {algorithm, grammar, mdl_score}, or None when iDRegEx was
                requested explicitly but produced nothing / '∅'.
            all: list of {algorithm, grammar, mdl_score}, best first.
            why: str explaining the choice.
    """
    def _entry(algorithm, grammar, score):
        # One result record; the score is rounded for presentation only.
        return {'algorithm': algorithm, 'grammar': grammar, 'mdl_score': round(score, 2)}

    def _count_matching(grammar):
        return sum(1 for s in sequences if _matches(grammar, s))

    choice = prefer.lower() if prefer else None

    # --- iDRegEx only ------------------------------------------------------
    if choice == 'idregex':
        idr_g = idregex(sequences, kmax=kmax, N=N)
        if not idr_g or idr_g == '∅':
            return {
                'best': None,
                'all': [],
                'why': "iDRegEx returned ∅ (no common core found).",
            }
        idr_score = mdl_score_simple(idr_g, sequences)
        return {
            'best': _entry('iDRegEx', idr_g, idr_score),
            'all': [_entry('iDRegEx', idr_g, idr_score)],
            'why': "Requested iDRegEx only.",
        }

    # --- CRX always runs from here on --------------------------------------
    crx_g = CRX().infer(sequences)
    crx_score = mdl_score_simple(crx_g, sequences)

    if choice == 'crx':
        return {
            'best': _entry('CRX', crx_g, crx_score),
            'all': [_entry('CRX', crx_g, crx_score)],
            'why': "Requested CRX only.",
        }

    # --- Full ensemble -----------------------------------------------------
    candidates = [('CRX', crx_g, crx_score)]
    idr_g = idregex(sequences, kmax=kmax, N=N)
    if idr_g and idr_g != '∅':
        candidates.append(('iDRegEx', idr_g, mdl_score_simple(idr_g, sequences)))

    # Stable sort: on a tie CRX stays ahead of iDRegEx.
    ranked = sorted(candidates, key=lambda c: c[2])
    winner = ranked[0]
    all_results = [_entry(name, grammar, score) for name, grammar, score in ranked]

    crx_match = _count_matching(crx_g)
    idr_match = _count_matching(idr_g) if len(ranked) > 1 else 0
    total = len(sequences)

    if len(ranked) == 1:
        why_parts = ["Only CRX produced a result (iDRegEx returned ∅)."]
    else:
        first, second = ranked[0], ranked[1]
        why_parts = [
            f"{first[0]} (score {first[2]:.1f}) vs {second[0]} (score {second[2]:.1f})."
        ]
        if crx_match == idr_match == total:
            why_parts.append("Both grammars match all sequences.")
            why_parts.append(
                f"{first[0]} wins because it is more compact "
                "(lower model cost) while matching all data."
            )
        elif crx_match != idr_match:
            why_parts.append(
                f"CRX matches {crx_match}/{total} sequences, "
                f"iDRegEx matches {idr_match}/{total}."
            )

    why_parts.append(f"{winner[0]} selected (MDL score {winner[2]:.1f}).")

    return {
        'best': _entry(winner[0], winner[1], winner[2]),
        'all': all_results,
        'why': ' '.join(why_parts),
    }
|