"""Ensemble grammar inference — run multiple algorithms, pick best by MDL scoring.""" import re from .crx import CRX from .idregex import idregex from .kore import kOREInference from .expr import alphabet from .mdl import model_cost, mdl_score def _parse_parts(expr): """Parse expression into a list of tokens for matching. Each token: (type, value, quantifier) type: 'symbol' | 'disj' | 'concat' | 'empty' quantifier: '' | '?' | '+' | '+?' """ if not expr or expr == '∅': return [('empty', '', '')] if expr == 'ε': return [('empty', '', '+?')] # 1. Check if it's a concatenation (split outermost by '.') # Must check BEFORE stripping trailing quantifier, because # quantifiers belong to individual parts (e.g., a?.b+) concat_parts = _split_outer(expr.strip(), '.') if len(concat_parts) > 1: children = [] for p in concat_parts: children.extend(_parse_parts(p.strip())) return [('concat', children, '')] # 2. Now handle quantifier suffix on this single part quantifier = '' if expr.endswith('+?'): quantifier = '+?' expr = expr[:-2] elif expr.endswith('*'): quantifier = '*' expr = expr[:-1] elif expr.endswith('?'): quantifier = '?' expr = expr[:-1] elif expr.endswith('+'): quantifier = '+' expr = expr[:-1] # 3. Disjunction group: (a+b+c) for CRX or (a|b|c) for iDRegEx if expr.startswith('(') and expr.endswith(')'): inner = expr[1:-1] # Try CRX-style (+) first, then iDRegEx-style (|) disj_parts = _split_outer(inner, '+') if len(disj_parts) <= 1: disj_parts = _split_outer(inner, '|') if len(disj_parts) > 1: children = [] for p in disj_parts: p = p.strip() # Parse as a flat symbol (don't split dots — they're part of # the symbol name, e.g. "community.docker.docker_image") children.append(_parse_flat_symbol(p)) return [('disj', children, quantifier)] # Single element inside parens: treat as flat symbol return [_parse_flat_symbol(inner)] # 4. Single symbol if expr and expr not in ('∅', 'ε'): return [('symbol', expr, quantifier)] return [] def _parse_flat_symbol(s): """Parse a single symbol with optional quantifier, no dot splitting. Unlike _parse_parts, this treats dots as part of the symbol name (e.g. 'community.docker.docker_image' stays as one symbol). """ s = s.strip() quantifier = '' if s.endswith('+?'): quantifier = '+?' s = s[:-2] elif s.endswith('*'): quantifier = '*' s = s[:-1] elif s.endswith('?'): quantifier = '?' s = s[:-1] elif s.endswith('+'): quantifier = '+' s = s[:-1] if s and s not in ('∅', 'ε'): return ('symbol', s, quantifier) return ('empty', '', quantifier) def _split_outer(s, sep): """Split on `sep` at the top level (not inside parentheses).""" depth = 0 parts = [] cur = [] for ch in s: if ch == '(': depth += 1 cur.append(ch) elif ch == ')': depth -= 1 cur.append(ch) elif ch == sep and depth == 0: parts.append(''.join(cur)) cur = [] else: cur.append(ch) parts.append(''.join(cur)) return parts def _match_possible(token, seq, pos): """Return all possible end positions after matching this token starting at pos.""" ttype, tval, tquant = token positions = [] if ttype == 'empty': positions.append(pos) elif ttype == 'symbol': if tquant in ('', '?'): if pos < len(seq) and seq[pos] == tval: positions.append(pos + 1) if tquant == '?': positions.append(pos) elif tquant in ('+?', '*'): positions.append(pos) cnt = pos while cnt < len(seq) and seq[cnt] == tval: cnt += 1 positions.append(cnt) elif tquant == '+': if pos < len(seq) and seq[pos] == tval: cnt = pos + 1 positions.append(cnt) while cnt < len(seq) and seq[cnt] == tval: cnt += 1 positions.append(cnt) elif ttype == 'disj': if tquant in ('', '?'): for child in tval: for ep in _match_possible(child, seq, pos): positions.append(ep) if tquant == '?': positions.append(pos) elif tquant in ('+?', '*'): positions.append(pos) for child in tval: for ep in _match_possible(child, seq, pos): if ep > pos: positions.append(ep) # After consuming one, recurse to try more for ep2 in _match_possible(token, seq, ep): if ep2 > ep: positions.append(ep2) elif tquant == '+': for child in tval: for ep in _match_possible(child, seq, pos): if ep > pos: positions.append(ep) for ep2 in _match_possible(token, seq, ep): if ep2 > ep: positions.append(ep2) elif ttype == 'concat': # Match all children sequentially def _match_seq(children, start): cur = [start] for child in children: next_cur = [] for p in cur: next_cur.extend(_match_possible(child, seq, p)) cur = next_cur if not cur: break return cur if tquant in ('', '?'): positions.extend(_match_seq(tval, pos)) if tquant == '?': positions.append(pos) elif tquant in ('+?', '*'): positions.append(pos) inner_end = _match_seq(tval, pos) for ep in inner_end: if ep > pos: positions.append(ep) for ep2 in _match_possible(token, seq, ep): if ep2 > ep: positions.append(ep2) elif tquant == '+': inner_end = _match_seq(tval, pos) for ep in inner_end: if ep > pos: positions.append(ep) for ep2 in _match_possible(token, seq, ep): if ep2 > ep: positions.append(ep2) return positions def _match_tokens(tokens, seq, pos=0): """Try to match tokens against seq starting at pos. Returns max position or None.""" cur = [pos] for token in tokens: next_cur = [] for p in cur: next_cur.extend(_match_possible(token, seq, p)) cur = next_cur if not cur: return None return max(cur) if cur else pos def _matches(grammar, sequence): """Check if a sequence matches the grammar.""" try: tokens = _parse_parts(grammar.strip()) if not tokens: return False end = _match_tokens(tokens, sequence) if end is None: return False return end == len(sequence) except Exception: return False def _fit_score(grammar, seq): """Score how tightly a sequence fits: 1.0 = perfect match to core, 0.0 = mostly uses optional/repeated parts. Instead of trying to parse the grammar structure (which is fragile), this measures how well seq matches against the grammatical core by comparing its symbol positions to the grammar's 'spine' — the symbols that appear in all sequences. """ if not seq: return 0.0 try: # Strategy: parse grammar tokens, match seq, count what fraction # of seq length is consumed by obligatory (non-?, non-+?) tokens. tokens = _parse_parts(grammar.strip()) if not tokens or tokens[0][0] == 'empty': return 0.0 def _classify_tokens(node): """Return (obligatory_count, optional_count) for this node.""" tt, tv, tq = node if tt == 'symbol': if tq in ('', '+'): return (1, 0) return (0, 1) if tt == 'concat': ob, op = 0, 0 for c in tv: if c[0] == 'empty': continue o1, o2 = _classify_tokens(c) ob += o1 op += o2 return (ob, op) if tt == 'disj': # Any alternative counts as optional return (0, len(tv)) return (0, 0) ob, op = _classify_tokens(tokens[0]) total = ob + op if total == 0: return 0.5 # Match seq and see how many symbols are actually consumed end = _match_tokens(tokens, seq) if end is None or end != len(seq): return 0.0 # Fit = fraction of mandatory symbols / total mandatory+optional # Penalizes sequences that lean heavily on optional parts return max(0.0, 1.0 - (op / total)) except Exception: return 0.0 def _symbol_rarity_score(seq, all_sequences): """Score a sequence by how rare its symbols are across the dataset. 1.0 = all symbols are common, 0.0 = mostly rare symbols. """ from collections import Counter all_syms = Counter() for s in all_sequences: all_syms.update(s) n = len(all_sequences) scores = [] for sym in seq: freq = all_syms.get(sym, 0) / n scores.append(min(freq, 1.0)) return sum(scores) / len(scores) if scores else 0.0 def _find_core(sequences, min_coverage=0.8): """Find the core subset of sequences by iterative CRX + outlier removal. Outlier detection uses symbol rarity: sequences with rare symbols (appearing in few other sequences) are removed first. Returns: (core_grammar, core_sequences, outliers, fit_scores) """ if not sequences or min_coverage >= 1.0: crx_g = CRX().infer(sequences) return crx_g, sequences, [], [] from collections import Counter all_syms = Counter() for s in sequences: all_syms.update(s) n = len(sequences) def _rarity(seq): rare_count = sum(1 for sym in seq if all_syms.get(sym, 0) / n < 0.3) return rare_count / max(len(seq), 1) working = list(sequences) removed_indices = [] crx = CRX() for _ in range(50): if len(working) < 3: break target = max(int(len(sequences) * min_coverage), 1) if len(working) <= target: break # Score by rarity: most rare symbol → worst fit scores = [(i, _rarity(seq)) for i, seq in enumerate(working)] scores.sort(key=lambda x: -x[1]) # most rare first # If all sequences have the same score, stop (no outliers to remove) if len(scores) < 2 or scores[0][1] == scores[-1][1]: break worst_idx = scores[0][0] removed_indices.append(working[worst_idx]) working = [s for i, s in enumerate(working) if i != worst_idx] core_g = crx.infer(working) if working else None return core_g, working, removed_indices, [] def mdl_score_simple(grammar, sequences): """MDL score from the paper: model_cost + Σ log₂(|L(r)| at length len(s)). Lower is better. Uses the paper's definition from Bex et al. model_cost = number of alphabet symbol occurrences in the expression. data_cost = Σ log₂(|L(r)|) — penalizes overly general grammars. """ return mdl_score(grammar, sequences) def _run_idregex(sequences, kmax, N): """Run standalone iDRegEx, return (grammar, score) or (None, inf).""" g = idregex(sequences, kmax=kmax, N=N) if g and g != '∅': return g, mdl_score_simple(g, sequences) return None, float('inf') def _run_kore(sequences, kmax, N): """Run kOREInference (Algorithm 4 with MDL), return (grammar, score) or (None, inf).""" kore = kOREInference(k_max=kmax, N=N) result = kore.infer(sequences) if result: _, expr, _ = result return expr, mdl_score_simple(expr, sequences) return None, float('inf') _ALGO_NAMES = { 'crx': 'CRX', 'idregex': 'iDRegEx', 'koreinference': 'kOREInference', } _ALGORITHMS = { 'crx': lambda s, k, n: (CRX().infer(s), mdl_score_simple(CRX().infer(s), s)), 'idregex': _run_idregex, 'koreinference': _run_kore, } def infer_ensemble(sequences, kmax=2, N=3, prefer=None, min_coverage=1.0): """Run all applicable algorithms and return the best by MDL score. Args: sequences: List of sequences, each a list of strings. kmax: Maximum k for k-ORE inference (iDRegEx, kOREInference). N: Number of random trials for k-ORE inference. prefer: Optional — 'crx', 'idregex', or 'koreinference' to skip ensemble and return only that algorithm's result. min_coverage: When < 1.0, also runs CRX on the tightest core subset of sequences. Outliers (worst-fitting) are iteratively removed until at least this fraction remains. The core grammar and outlier list are included in the response. Returns: dict with keys: best: {algorithm, grammar, mdl_score} all: [{algorithm, grammar, mdl_score}, ...] why: str explaining the choice core: (optional) {grammar, coverage, outliers} — only when min_coverage < 1.0 """ if prefer and prefer.lower() in _ALGORITHMS: key = prefer.lower() fn = _ALGORITHMS[key] algo_name = _ALGO_NAMES.get(key, key) g, score = fn(sequences, kmax, N) if g and g != '∅': return { 'best': {'algorithm': algo_name, 'grammar': g, 'mdl_score': round(score, 2)}, 'all': [{'algorithm': algo_name, 'grammar': g, 'mdl_score': round(score, 2)}], 'why': f"Requested {algo_name} only.", } return { 'best': None, 'all': [], 'why': f"{algo_name} returned ∅ (no grammar found).", } results = [] # 1. CRX (always fast, always produces a result) crx_g = CRX().infer(sequences) crx_score = mdl_score_simple(crx_g, sequences) if crx_g and crx_g != '∅' else float('inf') results.append(('CRX', crx_g if crx_g and crx_g != '∅' else '∅', crx_score)) # 2. iDRegEx (standalone, langsize-based) idr_g, idr_score = _run_idregex(sequences, kmax, N) if idr_g: results.append(('iDRegEx', idr_g, idr_score)) # 3. kOREInference (Algorithm 4 with MDL scoring) kore_g, kore_score = _run_kore(sequences, kmax, N) if kore_g: results.append(('kOREInference', kore_g, kore_score)) results = [r for r in results if r[1] and r[1] != '∅'] if not results: base = { 'best': None, 'all': [], 'why': "No algorithm produced a non-empty grammar.", } if min_coverage < 1.0: core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage) base['core'] = { 'grammar': core_g, 'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0, 'outliers': outliers, } return base results.sort(key=lambda x: x[2]) best = results[0] all_results = [ {'algorithm': a, 'grammar': g, 'mdl_score': round(s, 2)} for a, g, s in results ] active = {r[0] for r in results} why_parts = [] if len(results) == 1: why_parts.append(f"Only {results[0][0]} produced a result.") else: scores_str = ', '.join(f"{r[0]}={r[2]:.1f}" for r in results) why_parts.append(f"Scores: {scores_str}.") match_strs = [] for r_algo, r_grammar, _ in results: if r_grammar and r_grammar != '∅': m = sum(1 for s in sequences if _matches(r_grammar, s)) match_strs.append(f"{r_algo}={m}/{len(sequences)}") if match_strs: why_parts.append(f"Match rates: {', '.join(match_strs)}.") why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).") result = { 'best': { 'algorithm': best[0], 'grammar': best[1], 'mdl_score': round(best[2], 2), }, 'all': all_results, 'why': ' '.join(why_parts), } # Core analysis when min_coverage < 1.0 if min_coverage < 1.0: core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage) result['core'] = { 'grammar': core_g, 'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0, 'outlier_count': len(outliers), 'outliers': outliers, } result['why'] += f' Core CRX ({min_coverage:.0%} coverage, {len(outliers)} outliers): {core_g}' return result