"""Ensemble grammar inference — run multiple algorithms, pick best by MDL scoring."""

import re

from .crx import CRX
from .idregex import idregex
from .expr import alphabet
from .mdl import model_cost, mdl_score

# Quantifier suffixes recognised on a symbol or a parenthesised group.
# Order matters: '+?' must be tried before its one-character suffixes.
_QUANTIFIER_SUFFIXES = ('+?', '*', '?', '+')

# Quantifiers that permit zero occurrences / more than one occurrence.
_OPTIONAL_QUANTIFIERS = frozenset({'?', '*', '+?'})
_REPEATING_QUANTIFIERS = frozenset({'+', '*', '+?'})


def _split_quantifier(s):
    """Split a trailing quantifier off ``s``.

    Returns:
        ``(body, quantifier)`` where quantifier is one of
        ``'+?' | '*' | '?' | '+'`` or ``''`` when there is none.
    """
    for suffix in _QUANTIFIER_SUFFIXES:
        if s.endswith(suffix):
            return s[:-len(suffix)], suffix
    return s, ''


def _combine_quantifiers(inner, outer):
    """Return the single quantifier equivalent to applying ``outer`` to ``inner``.

    Used when a group holds a single element, e.g. ``(a?)+`` or ``(a+)?`` —
    both mean "zero or more a", returned as ``'+?'``.
    """
    if not outer:
        return inner
    if not inner:
        return outer
    optional = inner in _OPTIONAL_QUANTIFIERS or outer in _OPTIONAL_QUANTIFIERS
    repeating = inner in _REPEATING_QUANTIFIERS or outer in _REPEATING_QUANTIFIERS
    if optional and repeating:
        return '+?'
    return '+' if repeating else '?'


def _parse_parts(expr):
    """Parse expression into a list of tokens for matching.

    Each token: (type, value, quantifier)
      type: 'symbol' | 'disj' | 'concat' | 'empty'
      quantifier: '' | '?' | '+' | '+?' | '*'

    For 'disj' and 'concat' the value is a list of child tokens; for
    'symbol' it is the symbol string; for 'empty' it is ''.
    """
    if expr:
        expr = expr.strip()
    if not expr or expr == '∅':
        return [('empty', '', '')]
    if expr == 'ε':
        return [('empty', '', '+?')]

    # 1. Concatenation (split outermost by '.').
    # Must check BEFORE stripping a trailing quantifier, because
    # quantifiers belong to individual parts (e.g., a?.b+).
    concat_parts = _split_outer(expr, '.')
    if len(concat_parts) > 1:
        children = []
        for part in concat_parts:
            children.extend(_parse_parts(part.strip()))
        return [('concat', children, '')]

    # 2. Quantifier suffix on this single part.
    expr, quantifier = _split_quantifier(expr)

    # 3. Disjunction group: (a+b+c) for CRX or (a|b|c) for iDRegEx.
    if expr.startswith('(') and expr.endswith(')'):
        inner = expr[1:-1]
        # Try CRX-style (+) first, then iDRegEx-style (|).
        disj_parts = _split_outer(inner, '+')
        if len(disj_parts) <= 1:
            disj_parts = _split_outer(inner, '|')
        if len(disj_parts) > 1:
            # Children are parsed as flat symbols: dots inside a group are part
            # of the symbol name (e.g. "community.docker.docker_image").
            children = [_parse_flat_symbol(part) for part in disj_parts]
            return [('disj', children, quantifier)]
        # Single element inside parens: it is a flat symbol, and the group's
        # quantifier still applies to it (previously it was dropped, so
        # "(a)+" matched exactly one "a").
        ttype, tval, inner_quant = _parse_flat_symbol(inner)
        return [(ttype, tval, _combine_quantifiers(inner_quant, quantifier))]

    # 4. Single symbol.
    if expr and expr not in ('∅', 'ε'):
        return [('symbol', expr, quantifier)]
    return []


def _parse_flat_symbol(s):
    """Parse a single symbol with optional quantifier, no dot splitting.

    Unlike _parse_parts, this treats dots as part of the symbol name
    (e.g. 'community.docker.docker_image' stays as one symbol).
    '∅', 'ε' and the empty string become an 'empty' token.
    """
    s, quantifier = _split_quantifier(s.strip())
    if s and s not in ('∅', 'ε'):
        return ('symbol', s, quantifier)
    return ('empty', '', quantifier)


def _split_outer(s, sep):
    """Split on the single character ``sep`` at the top level (not inside parentheses)."""
    depth = 0
    parts = []
    cur = []
    for ch in s:
        if ch == '(':
            depth += 1
        elif ch == ')':
            depth -= 1
        elif ch == sep and depth == 0:
            parts.append(''.join(cur))
            cur = []
            continue
        cur.append(ch)
    parts.append(''.join(cur))
    return parts


def _match_sequence(tokens, seq, start):
    """Return the set of end positions after matching ``tokens`` in order from ``start``.

    An empty ``tokens`` list matches trivially and yields ``{start}``; an
    empty set means the tokens cannot be matched from ``start``.
    """
    ends = {start}
    for token in tokens:
        ends = {ep for p in ends for ep in _match_possible(token, seq, p)}
        if not ends:
            break
    return ends


def _repeat(step, start):
    """Return every position reachable from ``start`` by one or more ``step`` calls.

    ``step(p)`` returns the end positions of a single iteration begun at ``p``.
    Implemented as an iterative worklist so long runs of repetitions do not
    deepen the Python call stack, and each position is expanded at most once.
    """
    reached = set()
    frontier = [start]
    while frontier:
        next_frontier = []
        for p in frontier:
            for ep in step(p):
                if ep not in reached:
                    reached.add(ep)
                    next_frontier.append(ep)
        frontier = next_frontier
    return reached


def _match_possible(token, seq, pos):
    """Return all possible end positions after matching this token starting at pos.

    The result is a sorted list without duplicates (an empty list means the
    token cannot match at ``pos``). Quantifiers: '' exactly once, '?' zero or
    one, '+' one or more, '*' / '+?' zero or more.
    """
    ttype, tval, tquant = token

    def once(start):
        # End positions after exactly one occurrence of the token body,
        # ignoring the token's own quantifier.
        if ttype == 'empty':
            return {start}
        if ttype == 'symbol':
            if start < len(seq) and seq[start] == tval:
                return {start + 1}
            return set()
        if ttype == 'disj':
            return {ep for child in tval for ep in _match_possible(child, seq, start)}
        if ttype == 'concat':
            return _match_sequence(tval, seq, start)
        return set()

    if tquant == '':
        ends = once(pos)
    elif tquant == '?':
        ends = once(pos) | {pos}
    elif tquant == '+':
        ends = _repeat(once, pos)
    elif tquant in ('*', '+?'):
        ends = _repeat(once, pos) | {pos}
    else:
        ends = set()
    return sorted(ends)


def _match_tokens(tokens, seq, pos=0):
    """Try to match tokens against seq starting at pos. Returns max position or None.

    None means some token could not be matched; an empty ``tokens`` list
    returns ``pos``.
    """
    ends = _match_sequence(tokens, seq, pos)
    return max(ends) if ends else None


def _matches(grammar, sequence):
    """Check if a sequence matches the grammar (the whole sequence must be consumed).

    Deliberately best-effort: a grammar that cannot be parsed or matched
    (including a non-string such as None) counts as "no match" instead of
    raising.
    """
    try:
        tokens = _parse_parts(grammar.strip())
        if not tokens:
            return False
        return _match_tokens(tokens, sequence) == len(sequence)
    except Exception:
        return False


def mdl_score_simple(grammar, sequences):
    """MDL score from the paper: model_cost + Σ log₂(|L(r)| at length len(s)).

    Lower is better. Thin wrapper that delegates to ``.mdl.mdl_score``.
    Intended definition (Bex et al.), as implemented in that module:
      model_cost = number of alphabet symbol occurrences in the expression.
      data_cost = Σ log₂(|L(r)|) — penalizes overly general grammars.
    """
    return mdl_score(grammar, sequences)


def _result_entry(algorithm, grammar, score):
    """Build one public result record, rounding the score to two decimals."""
    return {
        'algorithm': algorithm,
        'grammar': grammar,
        'mdl_score': round(score, 2),
    }


def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
    """Run all applicable algorithms and return the best by MDL score.

    Args:
        sequences: List of sequences, each a list of strings.
        kmax: Maximum k for iDRegEx k-ORE inference.
        N: Number of EM iterations for iDRegEx.
        prefer: Optional — 'crx' or 'idregex' (case-insensitive) to skip the
            ensemble and return only that algorithm's result. Any other
            value runs the full ensemble.

    Returns:
        dict with keys:
            best: {algorithm, grammar, mdl_score} (None if the preferred
                iDRegEx run returned ∅)
            all: [{algorithm, grammar, mdl_score}, ...] sorted by score
            why: str explaining the choice
    """
    choice = prefer.lower() if prefer else None

    if choice == 'idregex':
        idr_g = idregex(sequences, kmax=kmax, N=N)
        if not idr_g or idr_g == '∅':
            return {
                'best': None,
                'all': [],
                'why': "iDRegEx returned ∅ (no common core found).",
            }
        idr_score = mdl_score_simple(idr_g, sequences)
        return {
            'best': _result_entry('iDRegEx', idr_g, idr_score),
            'all': [_result_entry('iDRegEx', idr_g, idr_score)],
            'why': "Requested iDRegEx only.",
        }

    crx_g = CRX().infer(sequences)
    crx_score = mdl_score_simple(crx_g, sequences)

    if choice == 'crx':
        return {
            'best': _result_entry('CRX', crx_g, crx_score),
            'all': [_result_entry('CRX', crx_g, crx_score)],
            'why': "Requested CRX only.",
        }

    results = [('CRX', crx_g, crx_score)]
    idr_g = idregex(sequences, kmax=kmax, N=N)
    has_idr = bool(idr_g) and idr_g != '∅'
    if has_idr:
        results.append(('iDRegEx', idr_g, mdl_score_simple(idr_g, sequences)))

    # list.sort is stable, so on a score tie CRX (inserted first) wins.
    results.sort(key=lambda r: r[2])
    best = results[0]
    all_results = [_result_entry(a, g, s) for a, g, s in results]

    if not has_idr:
        why_parts = ["Only CRX produced a result (iDRegEx returned ∅)."]
    else:
        runner_up = results[1]
        why_parts = [
            f"{best[0]} (score {best[2]:.1f}) vs {runner_up[0]} (score {runner_up[2]:.1f})."
        ]
        # Match counts are only meaningful when both algorithms produced a grammar.
        total = len(sequences)
        crx_match = sum(1 for s in sequences if _matches(crx_g, s))
        idr_match = sum(1 for s in sequences if _matches(idr_g, s))
        if crx_match == idr_match == total:
            why_parts.append("Both grammars match all sequences.")
            why_parts.append(
                f"{best[0]} wins because it is more compact "
                f"(lower model cost) while matching all data."
            )
        elif crx_match != idr_match:
            why_parts.append(
                f"CRX matches {crx_match}/{total} sequences, "
                f"iDRegEx matches {idr_match}/{total}."
            )

    why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).")

    return {
        'best': _result_entry(best[0], best[1], best[2]),
        'all': all_results,
        'why': ' '.join(why_parts),
    }