grammar-inference-engine/bex/ensemble.py

350 lines
11 KiB
Python
Raw Normal View History

"""Ensemble grammar inference — run multiple algorithms, pick best by MDL scoring."""
import re
from .crx import CRX
from .idregex import idregex
from .expr import alphabet
from .mdl import model_cost, mdl_score
def _combine_quantifiers(inner, outer):
    """Collapse an inner and an outer quantifier into one equivalent quantifier.

    Used for single-element groups such as ``(a?)+`` where both the element
    and the surrounding parentheses carry a quantifier.
    """
    if not inner:
        return outer
    if not outer:
        return inner
    # x?? == x? and x++ == x+; every other pairing permits zero or more.
    if inner == outer and inner in ('?', '+'):
        return inner
    return '*'


def _parse_parts(expr):
    """Parse a grammar expression into a list of tokens for matching.

    Each token is a ``(type, value, quantifier)`` tuple where

    * type is ``'symbol'``, ``'disj'``, ``'concat'`` or ``'empty'``;
    * value is the symbol name (``'symbol'``), a list of child tokens
      (``'disj'`` / ``'concat'``) or ``''`` (``'empty'``);
    * quantifier is ``''``, ``'?'``, ``'+'``, ``'*'`` or ``'+?'``.

    Top-level dots always separate concatenated parts.  Inside a
    parenthesised group dots are kept as part of the symbol name
    (e.g. ``community.docker.docker_image``).

    Args:
        expr: The expression string (may be empty or ``None``).

    Returns:
        A list of tokens; ``[]`` when nothing parseable is left after
        removing a quantifier suffix.
    """
    if not expr:
        return [('empty', '', '')]
    if expr == 'ε':
        return [('empty', '', '+?')]

    # 1. Concatenation: split on outermost '.'.  This must happen BEFORE the
    #    trailing quantifier is stripped, because quantifiers belong to the
    #    individual parts (e.g. a?.b+).
    concat_parts = _split_outer(expr.strip(), '.')
    if len(concat_parts) > 1:
        children = []
        for part in concat_parts:
            children.extend(_parse_parts(part.strip()))
        return [('concat', children, '')]

    # 2. Quantifier suffix on this single part.  '+?' has to be tested before
    #    '?' and '+' because it ends with both.
    quantifier = ''
    for suffix in ('+?', '*', '?', '+'):
        if expr.endswith(suffix):
            quantifier = suffix
            expr = expr[:-len(suffix)]
            break

    # 3. Parenthesised group: (a|b|c) for iDRegEx or (a+b+c) for CRX.
    if expr.startswith('(') and expr.endswith(')'):
        inner = expr[1:-1]
        # Prefer '|' as the separator: when it is present, any '+' inside the
        # group is a quantifier on an alternative (e.g. "(a+|b)"), not a
        # CRX-style separator.
        disj_parts = _split_outer(inner, '|')
        if len(disj_parts) <= 1:
            disj_parts = _split_outer(inner, '+')
        if len(disj_parts) > 1:
            # Alternatives are flat symbols: dots are not split here.
            children = [_parse_flat_symbol(part) for part in disj_parts]
            return [('disj', children, quantifier)]
        # Single element inside parens: flat symbol, but keep the quantifier
        # that was written after the closing parenthesis.
        kind, value, inner_quantifier = _parse_flat_symbol(inner)
        return [(kind, value, _combine_quantifiers(inner_quantifier, quantifier))]

    # 4. Single symbol.
    if expr and expr != 'ε':
        return [('symbol', expr, quantifier)]
    return []
def _parse_flat_symbol(s):
"""Parse a single symbol with optional quantifier, no dot splitting.
Unlike _parse_parts, this treats dots as part of the symbol name
(e.g. 'community.docker.docker_image' stays as one symbol).
"""
s = s.strip()
quantifier = ''
if s.endswith('+?'):
quantifier = '+?'
s = s[:-2]
elif s.endswith('*'):
quantifier = '*'
s = s[:-1]
elif s.endswith('?'):
quantifier = '?'
s = s[:-1]
elif s.endswith('+'):
quantifier = '+'
s = s[:-1]
if s and s not in ('', 'ε'):
return ('symbol', s, quantifier)
return ('empty', '', quantifier)
def _split_outer(s, sep):
"""Split on `sep` at the top level (not inside parentheses)."""
depth = 0
parts = []
cur = []
for ch in s:
if ch == '(':
depth += 1
cur.append(ch)
elif ch == ')':
depth -= 1
cur.append(ch)
elif ch == sep and depth == 0:
parts.append(''.join(cur))
cur = []
else:
cur.append(ch)
parts.append(''.join(cur))
return parts
def _match_possible(token, seq, pos):
"""Return all possible end positions after matching this token starting at pos."""
ttype, tval, tquant = token
positions = []
if ttype == 'empty':
positions.append(pos)
elif ttype == 'symbol':
if tquant in ('', '?'):
if pos < len(seq) and seq[pos] == tval:
positions.append(pos + 1)
if tquant == '?':
positions.append(pos)
elif tquant in ('+?', '*'):
positions.append(pos)
cnt = pos
while cnt < len(seq) and seq[cnt] == tval:
cnt += 1
positions.append(cnt)
elif tquant == '+':
if pos < len(seq) and seq[pos] == tval:
cnt = pos + 1
positions.append(cnt)
while cnt < len(seq) and seq[cnt] == tval:
cnt += 1
positions.append(cnt)
elif ttype == 'disj':
if tquant in ('', '?'):
for child in tval:
for ep in _match_possible(child, seq, pos):
positions.append(ep)
if tquant == '?':
positions.append(pos)
elif tquant in ('+?', '*'):
positions.append(pos)
for child in tval:
for ep in _match_possible(child, seq, pos):
if ep > pos:
positions.append(ep)
# After consuming one, recurse to try more
for ep2 in _match_possible(token, seq, ep):
if ep2 > ep:
positions.append(ep2)
elif tquant == '+':
for child in tval:
for ep in _match_possible(child, seq, pos):
if ep > pos:
positions.append(ep)
for ep2 in _match_possible(token, seq, ep):
if ep2 > ep:
positions.append(ep2)
elif ttype == 'concat':
# Match all children sequentially
def _match_seq(children, start):
cur = [start]
for child in children:
next_cur = []
for p in cur:
next_cur.extend(_match_possible(child, seq, p))
cur = next_cur
if not cur:
break
return cur
if tquant in ('', '?'):
positions.extend(_match_seq(tval, pos))
if tquant == '?':
positions.append(pos)
elif tquant in ('+?', '*'):
positions.append(pos)
inner_end = _match_seq(tval, pos)
for ep in inner_end:
if ep > pos:
positions.append(ep)
for ep2 in _match_possible(token, seq, ep):
if ep2 > ep:
positions.append(ep2)
elif tquant == '+':
inner_end = _match_seq(tval, pos)
for ep in inner_end:
if ep > pos:
positions.append(ep)
for ep2 in _match_possible(token, seq, ep):
if ep2 > ep:
positions.append(ep2)
return positions
def _match_tokens(tokens, seq, pos=0):
    """Match `tokens` one after another against `seq`, starting at `pos`.

    Args:
        tokens: List of ``(type, value, quantifier)`` tokens.
        seq: The sequence being matched.
        pos: Start index (default 0).

    Returns:
        The largest end position reachable after all tokens have matched, or
        None as soon as some token cannot match from any reachable position.
        With an empty token list the result is `pos` itself.
    """
    # A set keeps each reachable position once, so the frontier cannot grow
    # beyond the number of distinct positions in `seq`.
    frontier = {pos}
    for token in tokens:
        frontier = {
            end
            for start in frontier
            for end in _match_possible(token, seq, start)
        }
        if not frontier:
            return None
    return max(frontier)
def _matches(grammar, sequence):
    """Tell whether `grammar` accepts the whole of `sequence`.

    Best-effort: any exception raised while parsing or matching is treated
    as "does not match" and False is returned.
    """
    try:
        parsed = _parse_parts(grammar.strip())
        if parsed:
            final = _match_tokens(parsed, sequence)
            if final is not None:
                # Accept only when the furthest reachable position is the end.
                return final == len(sequence)
        return False
    except Exception:
        return False
def mdl_score_simple(grammar, sequences):
    """Return the MDL score of `grammar` for `sequences`.

    This is a thin delegate to `mdl_score` from the `.mdl` module; the
    scoring itself is defined there.  Callers in this file treat a lower
    score as better.

    NOTE(review): earlier notes on this function describe the score as
    model cost plus a data cost of Σ log₂(|L(r)|) following Bex et al. —
    confirm against the `.mdl` implementation.
    """
    score = mdl_score(grammar, sequences)
    return score
def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
    """Run the applicable inference algorithms and return the best by MDL score.

    Args:
        sequences: List of sequences, each a list of strings.
        kmax: Passed to `idregex` as `kmax`.
        N: Passed to `idregex` as `N`.
        prefer: Optional 'crx' or 'idregex' (case-insensitive) to skip the
            ensemble and return only that algorithm's result.  Any other
            value is ignored and the full ensemble runs.

    Returns:
        dict with keys:
            best: {algorithm, grammar, mdl_score}, or None when iDRegEx was
                requested explicitly and produced no grammar.
            all: [{algorithm, grammar, mdl_score}, ...] sorted by score,
                lowest first.
            why: str explaining the choice.
    """

    def _entry(algorithm, grammar, score):
        # One result record, in the shape shared by 'best' and 'all'.
        return {'algorithm': algorithm, 'grammar': grammar, 'mdl_score': round(score, 2)}

    wanted = prefer.lower() if prefer else None

    if wanted == 'idregex':
        idr_g = idregex(sequences, kmax=kmax, N=N)
        if not idr_g:
            return {
                'best': None,
                'all': [],
                'why': "iDRegEx returned ∅ (no common core found).",
            }
        idr_score = mdl_score_simple(idr_g, sequences)
        return {
            'best': _entry('iDRegEx', idr_g, idr_score),
            'all': [_entry('iDRegEx', idr_g, idr_score)],
            'why': "Requested iDRegEx only.",
        }

    crx_g = CRX().infer(sequences)
    crx_score = mdl_score_simple(crx_g, sequences)
    if wanted == 'crx':
        return {
            'best': _entry('CRX', crx_g, crx_score),
            'all': [_entry('CRX', crx_g, crx_score)],
            'why': "Requested CRX only.",
        }

    results = [('CRX', crx_g, crx_score)]
    idr_g = idregex(sequences, kmax=kmax, N=N)
    if idr_g:
        results.append(('iDRegEx', idr_g, mdl_score_simple(idr_g, sequences)))

    # Lowest score first; the sort is stable, so CRX stays ahead on a tie.
    results.sort(key=lambda item: item[2])
    best = results[0]

    if len(results) == 1:
        why_parts = ["Only CRX produced a result (iDRegEx returned ∅)."]
    else:
        # Match counts only matter when there are two grammars to compare.
        total = len(sequences)
        crx_match = sum(1 for s in sequences if _matches(crx_g, s))
        idr_match = sum(1 for s in sequences if _matches(idr_g, s))
        why_parts = [
            f"{results[0][0]} (score {results[0][2]:.1f}) vs {results[1][0]} (score {results[1][2]:.1f})."
        ]
        if crx_match == idr_match == total:
            why_parts.append("Both grammars match all sequences.")
            why_parts.append(
                f"{results[0][0]} wins because it is more compact "
                f"(lower model cost) while matching all data."
            )
        elif crx_match != idr_match:
            why_parts.append(
                f"CRX matches {crx_match}/{total} sequences, "
                f"iDRegEx matches {idr_match}/{total}."
            )
    why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).")

    return {
        'best': _entry(*best),
        'all': [_entry(*result) for result in results],
        'why': ' '.join(why_parts),
    }