521 lines
17 KiB
Python
521 lines
17 KiB
Python
"""Ensemble grammar inference — run multiple algorithms, pick best by MDL scoring."""
|
|
|
|
import re
|
|
from .crx import CRX
|
|
from .idregex import idregex
|
|
from .kore import kOREInference
|
|
from .expr import alphabet
|
|
from .mdl import model_cost, mdl_score
|
|
|
|
|
|
def _parse_parts(expr):
    """Parse a grammar expression into a token tree for matching.

    Each token is a tuple ``(type, value, quantifier)``:
        type: 'symbol' | 'disj' | 'concat' | 'empty'
        value: the symbol name (str) for 'symbol', a list of child tokens
            for 'disj' / 'concat', and '' for 'empty'
        quantifier: '' | '?' | '+' | '*' | '+?'

    Returns:
        A list holding a single root token, or an empty list when the
        expression reduces to a bare '∅' / 'ε' after quantifier stripping.
    """
    if not expr or expr == '∅':
        return [('empty', '', '')]
    if expr == 'ε':
        return [('empty', '', '+?')]

    # 1. Concatenation: split on top-level '.' BEFORE stripping a trailing
    #    quantifier, because quantifiers belong to individual parts
    #    (e.g. a?.b+).
    concat_parts = _split_outer(expr.strip(), '.')
    if len(concat_parts) > 1:
        children = []
        for part in concat_parts:
            children.extend(_parse_parts(part.strip()))
        return [('concat', children, '')]

    # 2. Quantifier suffix on this single part ('+?' must be tested first,
    #    otherwise it would be read as a plain '?').
    quantifier = ''
    for suffix in ('+?', '*', '?', '+'):
        if expr.endswith(suffix):
            quantifier = suffix
            expr = expr[:-len(suffix)]
            break

    # 3. Parenthesised group: (a+b+c) for CRX or (a|b|c) for iDRegEx.
    if expr.startswith('(') and expr.endswith(')'):
        inner = expr[1:-1]
        # Try '|' first: when a top-level '|' is present, any '+' inside the
        # group is a quantifier (e.g. "(a+|b)"), not an alternation separator.
        disj_parts = _split_outer(inner, '|')
        if len(disj_parts) <= 1:
            disj_parts = _split_outer(inner, '+')
        if len(disj_parts) > 1:
            # Alternatives are parsed as flat symbols: dots are part of the
            # symbol name here (e.g. "community.docker.docker_image").
            children = [_parse_flat_symbol(p.strip()) for p in disj_parts]
            return [('disj', children, quantifier)]

        # Single element inside parens: a flat symbol that must still carry
        # the quantifier written after the closing parenthesis.
        kind, value, inner_quant = _parse_flat_symbol(inner)
        if not quantifier:
            merged = inner_quant
        elif not inner_quant or inner_quant == quantifier:
            merged = quantifier
        else:
            # Mixed quantifiers such as (a?)+ or (a+)? mean "zero or more".
            merged = '*'
        return [(kind, value, merged)]

    # 4. Single symbol
    if expr and expr not in ('∅', 'ε'):
        return [('symbol', expr, quantifier)]

    return []
|
|
|
|
|
|
def _parse_flat_symbol(s):
|
|
"""Parse a single symbol with optional quantifier, no dot splitting.
|
|
|
|
Unlike _parse_parts, this treats dots as part of the symbol name
|
|
(e.g. 'community.docker.docker_image' stays as one symbol).
|
|
"""
|
|
s = s.strip()
|
|
quantifier = ''
|
|
if s.endswith('+?'):
|
|
quantifier = '+?'
|
|
s = s[:-2]
|
|
elif s.endswith('*'):
|
|
quantifier = '*'
|
|
s = s[:-1]
|
|
elif s.endswith('?'):
|
|
quantifier = '?'
|
|
s = s[:-1]
|
|
elif s.endswith('+'):
|
|
quantifier = '+'
|
|
s = s[:-1]
|
|
if s and s not in ('∅', 'ε'):
|
|
return ('symbol', s, quantifier)
|
|
return ('empty', '', quantifier)
|
|
|
|
|
|
def _split_outer(s, sep):
|
|
"""Split on `sep` at the top level (not inside parentheses)."""
|
|
depth = 0
|
|
parts = []
|
|
cur = []
|
|
for ch in s:
|
|
if ch == '(':
|
|
depth += 1
|
|
cur.append(ch)
|
|
elif ch == ')':
|
|
depth -= 1
|
|
cur.append(ch)
|
|
elif ch == sep and depth == 0:
|
|
parts.append(''.join(cur))
|
|
cur = []
|
|
else:
|
|
cur.append(ch)
|
|
parts.append(''.join(cur))
|
|
return parts
|
|
|
|
|
|
def _match_possible(token, seq, pos):
|
|
"""Return all possible end positions after matching this token starting at pos."""
|
|
ttype, tval, tquant = token
|
|
positions = []
|
|
|
|
if ttype == 'empty':
|
|
positions.append(pos)
|
|
|
|
elif ttype == 'symbol':
|
|
if tquant in ('', '?'):
|
|
if pos < len(seq) and seq[pos] == tval:
|
|
positions.append(pos + 1)
|
|
if tquant == '?':
|
|
positions.append(pos)
|
|
elif tquant in ('+?', '*'):
|
|
positions.append(pos)
|
|
cnt = pos
|
|
while cnt < len(seq) and seq[cnt] == tval:
|
|
cnt += 1
|
|
positions.append(cnt)
|
|
elif tquant == '+':
|
|
if pos < len(seq) and seq[pos] == tval:
|
|
cnt = pos + 1
|
|
positions.append(cnt)
|
|
while cnt < len(seq) and seq[cnt] == tval:
|
|
cnt += 1
|
|
positions.append(cnt)
|
|
|
|
elif ttype == 'disj':
|
|
if tquant in ('', '?'):
|
|
for child in tval:
|
|
for ep in _match_possible(child, seq, pos):
|
|
positions.append(ep)
|
|
if tquant == '?':
|
|
positions.append(pos)
|
|
elif tquant in ('+?', '*'):
|
|
positions.append(pos)
|
|
for child in tval:
|
|
for ep in _match_possible(child, seq, pos):
|
|
if ep > pos:
|
|
positions.append(ep)
|
|
# After consuming one, recurse to try more
|
|
for ep2 in _match_possible(token, seq, ep):
|
|
if ep2 > ep:
|
|
positions.append(ep2)
|
|
elif tquant == '+':
|
|
for child in tval:
|
|
for ep in _match_possible(child, seq, pos):
|
|
if ep > pos:
|
|
positions.append(ep)
|
|
for ep2 in _match_possible(token, seq, ep):
|
|
if ep2 > ep:
|
|
positions.append(ep2)
|
|
|
|
elif ttype == 'concat':
|
|
# Match all children sequentially
|
|
def _match_seq(children, start):
|
|
cur = [start]
|
|
for child in children:
|
|
next_cur = []
|
|
for p in cur:
|
|
next_cur.extend(_match_possible(child, seq, p))
|
|
cur = next_cur
|
|
if not cur:
|
|
break
|
|
return cur
|
|
if tquant in ('', '?'):
|
|
positions.extend(_match_seq(tval, pos))
|
|
if tquant == '?':
|
|
positions.append(pos)
|
|
elif tquant in ('+?', '*'):
|
|
positions.append(pos)
|
|
inner_end = _match_seq(tval, pos)
|
|
for ep in inner_end:
|
|
if ep > pos:
|
|
positions.append(ep)
|
|
for ep2 in _match_possible(token, seq, ep):
|
|
if ep2 > ep:
|
|
positions.append(ep2)
|
|
elif tquant == '+':
|
|
inner_end = _match_seq(tval, pos)
|
|
for ep in inner_end:
|
|
if ep > pos:
|
|
positions.append(ep)
|
|
for ep2 in _match_possible(token, seq, ep):
|
|
if ep2 > ep:
|
|
positions.append(ep2)
|
|
|
|
return positions
|
|
|
|
|
|
def _match_tokens(tokens, seq, pos=0):
    """Match `tokens` one after another against `seq`, starting at `pos`.

    Returns:
        The furthest end position reachable after all tokens have matched,
        or None as soon as some token cannot match from any candidate
        position. With no tokens the result is `pos` itself.
    """
    frontier = [pos]
    for tok in tokens:
        frontier = [
            end
            for start in frontier
            for end in _match_possible(tok, seq, start)
        ]
        if not frontier:
            return None
    # frontier is never empty here: it starts as [pos] and an empty
    # frontier returns None above.
    return max(frontier)
|
|
|
|
|
|
def _matches(grammar, sequence):
    """Report whether `sequence` is fully consumed by `grammar`.

    Best-effort: any exception raised while parsing or matching is treated
    as "does not match" and yields False.
    """
    try:
        parsed = _parse_parts(grammar.strip())
        if parsed:
            consumed = _match_tokens(parsed, sequence)
            # A match only counts when the whole sequence is consumed.
            return consumed is not None and consumed == len(sequence)
        return False
    except Exception:
        return False
|
|
|
|
|
|
def _fit_score(grammar, seq):
    """Score in [0.0, 1.0] for how much of `grammar` is obligatory.

    The grammar's root token is walked and its symbols are tallied as
    obligatory (quantifier '' or '+') or optional (any other quantifier);
    every alternative of a disjunction counts as optional. The score is
    ``1 - optional / (obligatory + optional)``.

    Returns:
        0.0 when `seq` is empty, the grammar parses to nothing or to an
        'empty' root, `seq` is not fully matched, or any exception occurs;
        0.5 when no symbols were tallied at all; otherwise the ratio above.

    Note that `seq` only gates the score (match / no match); the ratio
    itself is derived from the grammar alone.
    """
    if not seq:
        return 0.0
    try:
        tokens = _parse_parts(grammar.strip())
        if not tokens:
            return 0.0
        root = tokens[0]
        if root[0] == 'empty':
            return 0.0

        def _tally(node):
            """Return (obligatory_count, optional_count) for `node`."""
            kind, value, quant = node
            if kind == 'symbol':
                return (1, 0) if quant in ('', '+') else (0, 1)
            if kind == 'disj':
                # Any alternative counts as optional.
                return (0, len(value))
            if kind == 'concat':
                required = 0
                optional = 0
                for child in value:
                    if child[0] != 'empty':
                        child_req, child_opt = _tally(child)
                        required += child_req
                        optional += child_opt
                return (required, optional)
            return (0, 0)

        required, optional = _tally(root)
        total = required + optional
        if total == 0:
            return 0.5

        # The sequence must be consumed entirely for the score to apply.
        end = _match_tokens(tokens, seq)
        if end is None or end != len(seq):
            return 0.0

        return max(0.0, 1.0 - (optional / total))
    except Exception:
        return 0.0
|
|
|
|
|
|
def _symbol_rarity_score(seq, all_sequences):
|
|
"""Score a sequence by how rare its symbols are across the dataset.
|
|
1.0 = all symbols are common, 0.0 = mostly rare symbols.
|
|
"""
|
|
from collections import Counter
|
|
all_syms = Counter()
|
|
for s in all_sequences:
|
|
all_syms.update(s)
|
|
n = len(all_sequences)
|
|
scores = []
|
|
for sym in seq:
|
|
freq = all_syms.get(sym, 0) / n
|
|
scores.append(min(freq, 1.0))
|
|
return sum(scores) / len(scores) if scores else 0.0
|
|
|
|
|
|
def _find_core(sequences, min_coverage=0.8):
    """Find the core subset of sequences by removing rare-symbol outliers.

    A symbol is "rare" when its total occurrence count divided by the number
    of sequences is below 0.3. A sequence's rarity is the fraction of its
    symbols that are rare. The sequence with the highest rarity is removed
    repeatedly until one of the stop conditions holds, then CRX is run on
    what is left.

    Stop conditions: fewer than 3 sequences remain; the kept count has
    reached ``ceil(len(sequences) * min_coverage)`` (so at least
    `min_coverage` of the input is always kept); all remaining sequences
    have the same rarity; or 50 sequences have been removed.

    Args:
        sequences: list of sequences, each an iterable of symbols.
        min_coverage: minimum fraction of `sequences` to keep. Values
            >= 1.0 (or empty input) skip outlier removal entirely.

    Returns:
        (core_grammar, core_sequences, outliers, fit_scores) where
        `outliers` is the list of removed sequences in removal order and
        `fit_scores` is always an empty list (never computed here).
    """
    if not sequences or min_coverage >= 1.0:
        crx_g = CRX().infer(sequences)
        return crx_g, sequences, [], []

    import math
    from collections import Counter

    all_syms = Counter()
    for s in sequences:
        all_syms.update(s)
    n = len(sequences)

    def _rarity(seq):
        rare_count = sum(1 for sym in seq if all_syms.get(sym, 0) / n < 0.3)
        return rare_count / max(len(seq), 1)

    # Round UP so the kept fraction never drops below min_coverage; the
    # small epsilon absorbs float noise such as 10 * 0.7 == 7.000000000000001.
    target = max(math.ceil(n * min_coverage - 1e-9), 1)

    working = list(sequences)
    outliers = []
    crx = CRX()

    # Hard cap of 50 removals, as in the original implementation.
    for _ in range(50):
        if len(working) < 3 or len(working) <= target:
            break

        rarities = [_rarity(seq) for seq in working]
        worst = max(rarities)
        # If all sequences have the same score there is no outlier to remove.
        if worst == min(rarities):
            break

        # First sequence with the highest rarity goes first (stable order).
        outliers.append(working.pop(rarities.index(worst)))

    core_g = crx.infer(working) if working else None
    return core_g, working, outliers, []
|
|
|
|
|
|
def mdl_score_simple(grammar, sequences):
    """Return the MDL score of `grammar` for `sequences` (lower is better).

    Thin wrapper that delegates to ``mdl_score`` from the ``.mdl`` module.
    Per the original author's note, the score follows the definition from
    Bex et al.: model_cost + Σ log₂(|L(r)|), where the data-cost term
    penalizes overly general grammars — see ``.mdl`` for the actual
    computation.
    """
    score = mdl_score(grammar, sequences)
    return score
|
|
|
|
|
|
def _run_idregex(sequences, kmax, N):
    """Run standalone iDRegEx and score its grammar.

    Returns:
        (grammar, mdl_score), or (None, inf) when iDRegEx yields nothing
        or the empty-language marker '∅'.
    """
    grammar = idregex(sequences, kmax=kmax, N=N)
    if not grammar or grammar == '∅':
        return None, float('inf')
    return grammar, mdl_score_simple(grammar, sequences)
|
|
|
|
|
|
def _run_kore(sequences, kmax, N):
    """Run kOREInference (Algorithm 4 with MDL) and score its grammar.

    The inference result is unpacked as a 3-tuple whose middle element is
    the expression; the other two elements are ignored here.

    Returns:
        (grammar, mdl_score), or (None, inf) when inference yields nothing,
        or an empty / '∅' expression (same convention as ``_run_idregex``,
        so an empty grammar is never passed to the scorer).
    """
    kore = kOREInference(k_max=kmax, N=N)
    result = kore.infer(sequences)
    if not result:
        return None, float('inf')

    _, expr, _ = result
    if not expr or expr == '∅':
        return None, float('inf')
    return expr, mdl_score_simple(expr, sequences)
|
|
|
|
|
|
# Display names for the algorithm keys accepted by `prefer`.
_ALGO_NAMES = {
    'crx': 'CRX',
    'idregex': 'iDRegEx',
    'koreinference': 'kOREInference',
}


def _run_crx(sequences, kmax, N):
    """Run CRX once and score its grammar.

    `kmax` and `N` are accepted only so the signature matches the other
    runners; CRX does not use them.

    Returns:
        (grammar, mdl_score), or (None, inf) when CRX yields nothing or '∅'.
    """
    g = CRX().infer(sequences)
    if g and g != '∅':
        return g, mdl_score_simple(g, sequences)
    return None, float('inf')


# Runner per algorithm key; each takes (sequences, kmax, N) and returns
# (grammar, score) or (None, inf).
_ALGORITHMS = {
    'crx': _run_crx,
    'idregex': _run_idregex,
    'koreinference': _run_kore,
}
|
|
|
|
|
|
def _core_report(sequences, min_coverage):
    """Run the core analysis and package it as the response's 'core' dict."""
    core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage)
    return {
        'grammar': core_g,
        'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0,
        'outlier_count': len(outliers),
        'outliers': outliers,
    }


def infer_ensemble(sequences, kmax=2, N=3, prefer=None, min_coverage=1.0):
    """Run all applicable algorithms and return the best by MDL score.

    Args:
        sequences: List of sequences, each a list of strings.
        kmax: Maximum k for k-ORE inference (iDRegEx, kOREInference).
        N: Number of random trials for k-ORE inference.
        prefer: Optional — 'crx', 'idregex', or 'koreinference'
            (case-insensitive) to skip the ensemble and return only that
            algorithm's result. Any other value is ignored and the full
            ensemble runs. On this path `min_coverage` is not used.
        min_coverage: When < 1.0, also runs CRX on the tightest core subset
            of sequences. Outliers (worst-fitting) are iteratively
            removed until at least this fraction remains. The core
            grammar and outlier list are included in the response.

    Returns:
        dict with keys:
            best: {algorithm, grammar, mdl_score}, or None when no
                algorithm produced a non-empty grammar
            all: [{algorithm, grammar, mdl_score}, ...] sorted by score,
                lowest (best) first
            why: str explaining the choice
            core: (optional) {grammar, coverage, outlier_count, outliers} —
                only when min_coverage < 1.0
    """
    if prefer and prefer.lower() in _ALGORITHMS:
        key = prefer.lower()
        fn = _ALGORITHMS[key]
        algo_name = _ALGO_NAMES.get(key, key)
        g, score = fn(sequences, kmax, N)
        if g and g != '∅':
            return {
                'best': {'algorithm': algo_name, 'grammar': g, 'mdl_score': round(score, 2)},
                'all': [{'algorithm': algo_name, 'grammar': g, 'mdl_score': round(score, 2)}],
                'why': f"Requested {algo_name} only.",
            }
        return {
            'best': None,
            'all': [],
            'why': f"{algo_name} returned ∅ (no grammar found).",
        }

    results = []

    # 1. CRX (always fast, always produces a result)
    crx_g = CRX().infer(sequences)
    crx_ok = bool(crx_g) and crx_g != '∅'
    crx_score = mdl_score_simple(crx_g, sequences) if crx_ok else float('inf')
    results.append(('CRX', crx_g if crx_ok else '∅', crx_score))

    # 2. iDRegEx (standalone, langsize-based)
    idr_g, idr_score = _run_idregex(sequences, kmax, N)
    if idr_g:
        results.append(('iDRegEx', idr_g, idr_score))

    # 3. kOREInference (Algorithm 4 with MDL scoring)
    kore_g, kore_score = _run_kore(sequences, kmax, N)
    if kore_g:
        results.append(('kOREInference', kore_g, kore_score))

    # Drop algorithms that produced no usable grammar.
    results = [r for r in results if r[1] and r[1] != '∅']
    if not results:
        base = {
            'best': None,
            'all': [],
            'why': "No algorithm produced a non-empty grammar.",
        }
        if min_coverage < 1.0:
            base['core'] = _core_report(sequences, min_coverage)
        return base

    # Lowest MDL score wins.
    results.sort(key=lambda x: x[2])
    best = results[0]
    all_results = [
        {'algorithm': a, 'grammar': g, 'mdl_score': round(s, 2)}
        for a, g, s in results
    ]

    why_parts = []
    if len(results) == 1:
        why_parts.append(f"Only {results[0][0]} produced a result.")
    else:
        scores_str = ', '.join(f"{r[0]}={r[2]:.1f}" for r in results)
        why_parts.append(f"Scores: {scores_str}.")

    # Every remaining result has a non-empty grammar, so each one gets a
    # match-rate entry.
    match_strs = []
    for r_algo, r_grammar, _ in results:
        m = sum(1 for s in sequences if _matches(r_grammar, s))
        match_strs.append(f"{r_algo}={m}/{len(sequences)}")
    if match_strs:
        why_parts.append(f"Match rates: {', '.join(match_strs)}.")

    why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).")

    result = {
        'best': {
            'algorithm': best[0],
            'grammar': best[1],
            'mdl_score': round(best[2], 2),
        },
        'all': all_results,
        'why': ' '.join(why_parts),
    }

    # Core analysis when min_coverage < 1.0
    if min_coverage < 1.0:
        core = _core_report(sequences, min_coverage)
        result['core'] = core
        result['why'] += (
            f' Core CRX ({min_coverage:.0%} coverage, '
            f'{core["outlier_count"]} outliers): {core["grammar"]}'
        )

    return result
|