grammar-inference-engine/bex/ensemble.py
tobjend 9045769d57
All checks were successful
ci/woodpecker/push/woodpecker Pipeline was successful
ci/woodpecker/pr/woodpecker Pipeline was successful
feat: core+outlier analysis via min_coverage parameter, 6 new tests
2026-07-01 15:09:10 +02:00

521 lines
17 KiB
Python

"""Ensemble grammar inference — run multiple algorithms, pick best by MDL scoring."""
import re
from .crx import CRX
from .idregex import idregex
from .kore import kOREInference
from .expr import alphabet
from .mdl import model_cost, mdl_score
def _parse_parts(expr):
    """Parse a grammar expression into a list of matcher tokens.

    Each token is a tuple ``(type, value, quantifier)``:

    * type: ``'symbol'`` | ``'disj'`` | ``'concat'`` | ``'empty'``
    * value: the symbol name (``'symbol'``), a list of child tokens
      (``'disj'`` / ``'concat'``), or ``''`` (``'empty'``)
    * quantifier: ``''`` | ``'?'`` | ``'+'`` | ``'*'`` | ``'+?'``

    Args:
        expr: Expression string, e.g. ``'a?.b+'`` or ``'(a+b+c)?'``.

    Returns:
        A list of tokens. A top-level concatenation yields a single
        ``'concat'`` token; an expression that reduces to nothing (for
        example ``'ε?'``) yields an empty list.
    """
    if not expr:
        return [('empty', '', '')]
    if expr == 'ε':
        return [('empty', '', '+?')]

    # 1. Concatenation: split the outermost level on '.'. This happens
    #    BEFORE stripping a trailing quantifier because quantifiers belong
    #    to the individual parts (e.g. a?.b+).
    concat_parts = _split_outer(expr.strip(), '.')
    if len(concat_parts) > 1:
        children = []
        for part in concat_parts:
            children.extend(_parse_parts(part.strip()))
        return [('concat', children, '')]

    # 2. Quantifier suffix on this single part. '+?' must be tried before
    #    '?' so the two-character suffix is not split.
    quantifier = ''
    for suffix in ('+?', '*', '?', '+'):
        if expr.endswith(suffix):
            quantifier = suffix
            expr = expr[:-len(suffix)]
            break

    # 3. Parenthesised group: (a+b+c) for CRX or (a|b|c) for iDRegEx.
    if expr.startswith('(') and expr.endswith(')'):
        inner = expr[1:-1]
        # Try CRX-style (+) first, then iDRegEx-style (|).
        disj_parts = _split_outer(inner, '+')
        if len(disj_parts) <= 1:
            disj_parts = _split_outer(inner, '|')
        if len(disj_parts) > 1:
            # Alternatives are flat symbols: dots are part of the symbol
            # name here (e.g. "community.docker.docker_image").
            children = [_parse_flat_symbol(part.strip()) for part in disj_parts]
            return [('disj', children, quantifier)]

        # Single element inside parens: a flat symbol that must still carry
        # the group's quantifier, otherwise "(a)+" would parse as a bare "a".
        ttype, tval, inner_quant = _parse_flat_symbol(inner)
        if not inner_quant or inner_quant == quantifier:
            merged = quantifier
        elif not quantifier:
            merged = inner_quant
        else:
            # Two different quantifiers stacked, e.g. (a+)? or (a?)+,
            # always collapse to "zero or more".
            merged = '*'
        return [(ttype, tval, merged)]

    # 4. Single symbol.
    if expr and expr != 'ε':
        return [('symbol', expr, quantifier)]
    return []
def _parse_flat_symbol(s):
"""Parse a single symbol with optional quantifier, no dot splitting.
Unlike _parse_parts, this treats dots as part of the symbol name
(e.g. 'community.docker.docker_image' stays as one symbol).
"""
s = s.strip()
quantifier = ''
if s.endswith('+?'):
quantifier = '+?'
s = s[:-2]
elif s.endswith('*'):
quantifier = '*'
s = s[:-1]
elif s.endswith('?'):
quantifier = '?'
s = s[:-1]
elif s.endswith('+'):
quantifier = '+'
s = s[:-1]
if s and s not in ('', 'ε'):
return ('symbol', s, quantifier)
return ('empty', '', quantifier)
def _split_outer(s, sep):
"""Split on `sep` at the top level (not inside parentheses)."""
depth = 0
parts = []
cur = []
for ch in s:
if ch == '(':
depth += 1
cur.append(ch)
elif ch == ')':
depth -= 1
cur.append(ch)
elif ch == sep and depth == 0:
parts.append(''.join(cur))
cur = []
else:
cur.append(ch)
parts.append(''.join(cur))
return parts
def _match_possible(token, seq, pos):
"""Return all possible end positions after matching this token starting at pos."""
ttype, tval, tquant = token
positions = []
if ttype == 'empty':
positions.append(pos)
elif ttype == 'symbol':
if tquant in ('', '?'):
if pos < len(seq) and seq[pos] == tval:
positions.append(pos + 1)
if tquant == '?':
positions.append(pos)
elif tquant in ('+?', '*'):
positions.append(pos)
cnt = pos
while cnt < len(seq) and seq[cnt] == tval:
cnt += 1
positions.append(cnt)
elif tquant == '+':
if pos < len(seq) and seq[pos] == tval:
cnt = pos + 1
positions.append(cnt)
while cnt < len(seq) and seq[cnt] == tval:
cnt += 1
positions.append(cnt)
elif ttype == 'disj':
if tquant in ('', '?'):
for child in tval:
for ep in _match_possible(child, seq, pos):
positions.append(ep)
if tquant == '?':
positions.append(pos)
elif tquant in ('+?', '*'):
positions.append(pos)
for child in tval:
for ep in _match_possible(child, seq, pos):
if ep > pos:
positions.append(ep)
# After consuming one, recurse to try more
for ep2 in _match_possible(token, seq, ep):
if ep2 > ep:
positions.append(ep2)
elif tquant == '+':
for child in tval:
for ep in _match_possible(child, seq, pos):
if ep > pos:
positions.append(ep)
for ep2 in _match_possible(token, seq, ep):
if ep2 > ep:
positions.append(ep2)
elif ttype == 'concat':
# Match all children sequentially
def _match_seq(children, start):
cur = [start]
for child in children:
next_cur = []
for p in cur:
next_cur.extend(_match_possible(child, seq, p))
cur = next_cur
if not cur:
break
return cur
if tquant in ('', '?'):
positions.extend(_match_seq(tval, pos))
if tquant == '?':
positions.append(pos)
elif tquant in ('+?', '*'):
positions.append(pos)
inner_end = _match_seq(tval, pos)
for ep in inner_end:
if ep > pos:
positions.append(ep)
for ep2 in _match_possible(token, seq, ep):
if ep2 > ep:
positions.append(ep2)
elif tquant == '+':
inner_end = _match_seq(tval, pos)
for ep in inner_end:
if ep > pos:
positions.append(ep)
for ep2 in _match_possible(token, seq, ep):
if ep2 > ep:
positions.append(ep2)
return positions
def _match_tokens(tokens, seq, pos=0):
    """Match `tokens` in order against `seq`, starting at `pos`.

    Args:
        tokens: List of ``(type, value, quantifier)`` tokens.
        seq: Sequence of symbols being matched.
        pos: Index in `seq` at which matching starts.

    Returns:
        The furthest position reachable after all tokens have matched, or
        ``None`` as soon as some token cannot match from any reachable
        position. With no tokens the result is `pos`.
    """
    # Keep the reachable positions as a set: duplicates carry no extra
    # information and would otherwise multiply the work done per token.
    frontier = {pos}
    for token in tokens:
        reached = set()
        for start in frontier:
            reached.update(_match_possible(token, seq, start))
        if not reached:
            return None
        frontier = reached
    return max(frontier)
def _matches(grammar, sequence):
    """Report whether `sequence` is matched in full by `grammar`.

    Best-effort: any exception raised while parsing or matching is treated
    as "does not match" and yields ``False``.
    """
    try:
        tokens = _parse_parts(grammar.strip())
        if tokens:
            end = _match_tokens(tokens, sequence)
            # A match only counts when it consumes the whole sequence.
            return end is not None and end == len(sequence)
        return False
    except Exception:
        return False
def _fit_score(grammar, seq):
    """Score a sequence by how much of its grammar is obligatory.

    The value is the share of obligatory entries among all counted entries
    of the first parsed token (symbols with quantifier '' or '+' count as
    obligatory; every other symbol and every disjunction alternative counts
    as optional). It is only reported when `seq` matches the grammar in full.

    Returns:
        0.0 for an empty `seq`, an empty/unparsable grammar, a sequence that
        does not match completely, or any exception; 0.5 when the grammar
        has nothing to count; otherwise a value in [0.0, 1.0].
    """
    if not seq:
        return 0.0
    try:
        tokens = _parse_parts(grammar.strip())
        if not tokens or tokens[0][0] == 'empty':
            return 0.0

        def _tally(node):
            """Return (obligatory, optional) entry counts for `node`."""
            kind, value, quant = node
            if kind == 'symbol':
                return (1, 0) if quant in ('', '+') else (0, 1)
            if kind == 'disj':
                # Every alternative of a disjunction counts as optional.
                return (0, len(value))
            if kind == 'concat':
                pairs = [_tally(child) for child in value if child[0] != 'empty']
                return (sum(p[0] for p in pairs), sum(p[1] for p in pairs))
            return (0, 0)

        mandatory, optional = _tally(tokens[0])
        total = mandatory + optional
        if total == 0:
            return 0.5

        # The ratio is only meaningful for sequences the grammar accepts.
        end = _match_tokens(tokens, seq)
        if end is None or end != len(seq):
            return 0.0
        return max(0.0, 1.0 - (optional / total))
    except Exception:
        return 0.0
def _symbol_rarity_score(seq, all_sequences):
"""Score a sequence by how rare its symbols are across the dataset.
1.0 = all symbols are common, 0.0 = mostly rare symbols.
"""
from collections import Counter
all_syms = Counter()
for s in all_sequences:
all_syms.update(s)
n = len(all_sequences)
scores = []
for sym in seq:
freq = all_syms.get(sym, 0) / n
scores.append(min(freq, 1.0))
return sum(scores) / len(scores) if scores else 0.0
def _find_core(sequences, min_coverage=0.8):
    """Find the core subset of sequences by rarity-based outlier removal, then run CRX on it.

    A symbol is treated as rare when its occurrence count divided by the
    number of sequences is below 0.3. Each sequence is scored by its fraction
    of rare symbols, and the highest-scoring sequence is removed repeatedly.
    Removal stops when any of these holds: the core has shrunk to the target
    size, fewer than 3 sequences remain, all remaining sequences score the
    same, or 50 sequences have been removed.

    Args:
        sequences: List of sequences, each a list of symbols.
        min_coverage: Fraction of `sequences` the core must retain. Values
            >= 1.0 disable outlier removal.

    Returns:
        ``(core_grammar, core_sequences, outliers, fit_scores)`` where
        `outliers` lists the removed sequences in removal order and
        `fit_scores` is always an empty list.
    """
    import math
    from collections import Counter

    if not sequences or min_coverage >= 1.0:
        return CRX().infer(sequences), sequences, [], []

    # NOTE(review): counts every occurrence, so a symbol repeated within a
    # single sequence looks common. Confirm whether per-sequence presence
    # was intended.
    symbol_counts = Counter()
    for seq in sequences:
        symbol_counts.update(seq)
    n = len(sequences)

    def _rarity(seq):
        rare = sum(1 for sym in seq if symbol_counts[sym] / n < 0.3)
        return rare / max(len(seq), 1)

    # Round UP so the core never drops below the requested fraction. The
    # small epsilon absorbs float noise such as 10 * 0.7 == 7.000000000000001.
    target = max(math.ceil(n * min_coverage - 1e-9), 1)

    working = list(sequences)
    # Symbol frequencies come from the full dataset and never change, so each
    # sequence's rarity can be computed once up front.
    rarities = [_rarity(seq) for seq in working]
    outliers = []
    for _ in range(50):
        if len(working) < 3 or len(working) <= target:
            break
        worst = max(rarities)
        # All remaining sequences score the same: nothing stands out.
        if worst == min(rarities):
            break
        idx = rarities.index(worst)  # earliest sequence among the worst
        outliers.append(working.pop(idx))
        del rarities[idx]

    core_g = CRX().infer(working) if working else None
    return core_g, working, outliers, []
def mdl_score_simple(grammar, sequences):
    """Return the MDL score of `grammar` for `sequences`.

    Thin wrapper that forwards both arguments unchanged to `mdl_score`
    from the `.mdl` module and returns its result.

    NOTE(review): presumably the score is a model cost plus a
    Σ log₂(|L(r)|) data cost (Bex et al.) where lower is better — confirm
    against the `.mdl` implementation.
    """
    score = mdl_score(grammar, sequences)
    return score
def _run_idregex(sequences, kmax, N):
    """Run standalone iDRegEx and score its grammar.

    Returns:
        ``(grammar, mdl_score)``, or ``(None, inf)`` when iDRegEx produced
        no (or an empty) grammar.
    """
    grammar = idregex(sequences, kmax=kmax, N=N)
    if not grammar:
        return None, float('inf')
    return grammar, mdl_score_simple(grammar, sequences)
def _run_kore(sequences, kmax, N):
    """Run kOREInference and score the expression it returns.

    Returns:
        ``(expression, mdl_score)``, or ``(None, inf)`` when the inference
        returned a falsy result.
    """
    outcome = kOREInference(k_max=kmax, N=N).infer(sequences)
    if not outcome:
        return None, float('inf')
    # The result is a 3-tuple; only its middle element (the expression) is used.
    _head, expr, _tail = outcome
    return expr, mdl_score_simple(expr, sequences)
def _run_crx(sequences, kmax, N):
    """Run CRX once and score its grammar.

    `kmax` and `N` are accepted only so every entry of `_ALGORITHMS` shares
    the same call signature; they are not used here.

    Returns:
        ``(grammar, mdl_score)``, or ``(None, inf)`` when CRX produced no
        (or an empty) grammar.
    """
    grammar = CRX().infer(sequences)
    if grammar and grammar != '':
        return grammar, mdl_score_simple(grammar, sequences)
    return None, float('inf')


# Display names, keyed by the lower-case algorithm key.
_ALGO_NAMES = {
    'crx': 'CRX',
    'idregex': 'iDRegEx',
    'koreinference': 'kOREInference',
}

# Runners, keyed like _ALGO_NAMES. Each takes (sequences, kmax, N) and
# returns (grammar, score).
_ALGORITHMS = {
    'crx': _run_crx,
    'idregex': _run_idregex,
    'koreinference': _run_kore,
}
def _core_summary(sequences, min_coverage):
    """Run `_find_core` and shape its result into the response's 'core' entry."""
    core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage)
    return {
        'grammar': core_g,
        'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0,
        'outlier_count': len(outliers),
        'outliers': outliers,
    }


def infer_ensemble(sequences, kmax=2, N=3, prefer=None, min_coverage=1.0):
    """Run all applicable algorithms and return the best by MDL score.

    Args:
        sequences: List of sequences, each a list of strings.
        kmax: Maximum k for k-ORE inference (iDRegEx, kOREInference).
        N: Number of random trials for k-ORE inference.
        prefer: Optional — 'crx', 'idregex', or 'koreinference'
            (case-insensitive) to skip the ensemble and return only that
            algorithm's result. Any other value is ignored and the full
            ensemble runs. No core analysis is performed on this path.
        min_coverage: When < 1.0, also runs CRX on the core subset of
            sequences found by `_find_core`, which removes outliers while
            aiming to keep at least this fraction of the sequences.

    Returns:
        dict with keys:
            best: {algorithm, grammar, mdl_score}, or None when no
                algorithm produced a grammar
            all:  [{algorithm, grammar, mdl_score}, ...] sorted by score
            why:  str explaining the choice
            core: (optional) {grammar, coverage, outlier_count, outliers} —
                only on the ensemble path when min_coverage < 1.0
    """
    if prefer and prefer.lower() in _ALGORITHMS:
        key = prefer.lower()
        algo_name = _ALGO_NAMES.get(key, key)
        g, score = _ALGORITHMS[key](sequences, kmax, N)
        if g and g != '':
            entry = {'algorithm': algo_name, 'grammar': g, 'mdl_score': round(score, 2)}
            return {
                'best': entry,
                # Separate dict so callers mutating one do not affect the other.
                'all': [dict(entry)],
                'why': f"Requested {algo_name} only.",
            }
        return {
            'best': None,
            'all': [],
            'why': f"{algo_name} returned ∅ (no grammar found).",
        }

    # Collect (algorithm, grammar, score) for every algorithm that produced
    # a non-empty grammar.
    results = []

    # 1. CRX
    crx_g = CRX().infer(sequences)
    if crx_g and crx_g != '':
        results.append(('CRX', crx_g, mdl_score_simple(crx_g, sequences)))

    # 2. iDRegEx (standalone)
    idr_g, idr_score = _run_idregex(sequences, kmax, N)
    if idr_g:
        results.append(('iDRegEx', idr_g, idr_score))

    # 3. kOREInference
    kore_g, kore_score = _run_kore(sequences, kmax, N)
    if kore_g:
        results.append(('kOREInference', kore_g, kore_score))

    if not results:
        response = {
            'best': None,
            'all': [],
            'why': "No algorithm produced a non-empty grammar.",
        }
        if min_coverage < 1.0:
            response['core'] = _core_summary(sequences, min_coverage)
        return response

    # Lowest MDL score wins; the sort is stable, so ties keep the order above.
    results.sort(key=lambda x: x[2])
    best = results[0]
    all_results = [
        {'algorithm': a, 'grammar': g, 'mdl_score': round(s, 2)}
        for a, g, s in results
    ]

    why_parts = []
    if len(results) == 1:
        why_parts.append(f"Only {results[0][0]} produced a result.")
    else:
        scores_str = ', '.join(f"{r[0]}={r[2]:.1f}" for r in results)
        why_parts.append(f"Scores: {scores_str}.")

    match_strs = []
    for r_algo, r_grammar, _ in results:
        m = sum(1 for s in sequences if _matches(r_grammar, s))
        match_strs.append(f"{r_algo}={m}/{len(sequences)}")
    why_parts.append(f"Match rates: {', '.join(match_strs)}.")
    why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).")

    result = {
        'best': {
            'algorithm': best[0],
            'grammar': best[1],
            'mdl_score': round(best[2], 2),
        },
        'all': all_results,
        'why': ' '.join(why_parts),
    }

    # Core analysis when min_coverage < 1.0
    if min_coverage < 1.0:
        core = _core_summary(sequences, min_coverage)
        result['core'] = core
        result['why'] += (
            f" Core CRX ({min_coverage:.0%} coverage, "
            f"{core['outlier_count']} outliers): {core['grammar']}"
        )
    return result