def _count_obligatory_optional(node):
    """Count obligatory and optional leaves under one parsed grammar node.

    Args:
        node: A ``(type, value, quantifier)`` triple. Only the types
            ``'symbol'``, ``'concat'``, ``'disj'`` and ``'empty'`` are
            recognised here; anything else contributes nothing.

    Returns:
        ``(obligatory_count, optional_count)``.
    """
    node_type, value, quantifier = node
    if node_type == 'symbol':
        # A bare symbol or a one-or-more symbol must occur at least once.
        return (1, 0) if quantifier in ('', '+') else (0, 1)
    if node_type == 'concat':
        obligatory = optional = 0
        for child in value:
            if child[0] == 'empty':
                continue
            child_obligatory, child_optional = _count_obligatory_optional(child)
            obligatory += child_obligatory
            optional += child_optional
        return (obligatory, optional)
    if node_type == 'disj':
        # Every alternative of a disjunction is counted as optional.
        return (0, len(value))
    return (0, 0)


def _fit_score(grammar, seq):
    """Score how much of a grammar that accepts ``seq`` is obligatory.

    The score is ``1 - optional / (obligatory + optional)`` computed over the
    first parsed token of ``grammar``; it depends on ``seq`` only through the
    accept/reject check.

    Args:
        grammar: Grammar string, handed to ``_parse_parts`` after stripping.
        seq: Sequence of symbols to check against the grammar.

    Returns:
        A float in ``[0.0, 1.0]``:
        * ``0.0`` if ``seq`` is empty, the grammar parses to nothing, the
          grammar does not consume the whole sequence, or any step raises;
        * ``0.5`` if the sequence matches but the grammar has no countable
          symbols;
        * otherwise the obligatory fraction (1.0 = no optional parts).
    """
    if not seq:
        return 0.0
    try:
        tokens = _parse_parts(grammar.strip())
        if not tokens or tokens[0][0] == 'empty':
            return 0.0

        # Reject non-matching sequences before anything else, so a rejected
        # sequence can never receive the neutral 0.5 score below.
        # NOTE(review): assumes _match_tokens returns the end offset of the
        # match, or None when there is no match; confirm against its definition.
        end = _match_tokens(tokens, seq)
        if end is None or end != len(seq):
            return 0.0

        obligatory, optional = _count_obligatory_optional(tokens[0])
        total = obligatory + optional
        if total == 0:
            return 0.5

        # Penalise grammars that lean heavily on optional parts.
        return max(0.0, 1.0 - (optional / total))
    except Exception:
        # Deliberately best-effort: a grammar that cannot be parsed or matched
        # simply scores as "no fit" rather than aborting the caller.
        return 0.0


def _symbol_sequence_counts(sequences):
    """Return a Counter mapping each symbol to the number of sequences containing it.

    Each sequence contributes at most 1 per symbol, so a symbol repeated many
    times inside a single sequence is not mistaken for a common one.
    """
    from collections import Counter

    counts = Counter()
    for seq in sequences:
        counts.update(set(seq))
    return counts


def _symbol_rarity_score(seq, all_sequences):
    """Score a sequence by how common its symbols are across the dataset.

    Despite the name, higher means more common: 1.0 when every symbol of
    ``seq`` occurs in every sequence of ``all_sequences``, 0.0 when none of
    them occurs anywhere.

    Args:
        seq: Sequence of hashable symbols to score.
        all_sequences: The dataset (a sized collection of sequences).

    Returns:
        Mean, over the positions of ``seq``, of the fraction of dataset
        sequences containing that symbol. ``0.0`` if ``seq`` or
        ``all_sequences`` is empty.
    """
    if not seq or not all_sequences:
        # Nothing to average, or nothing to compare against (avoids n == 0).
        return 0.0
    counts = _symbol_sequence_counts(all_sequences)
    total = len(all_sequences)
    return sum(counts[sym] / total for sym in seq) / len(seq)


def _split_outliers(sequences, min_coverage, rare_threshold=0.3, max_removals=50):
    """Partition sequences into a core and rarity-based outliers.

    A symbol is "rare" when it occurs in fewer than ``rare_threshold`` of all
    sequences; a sequence's rarity is the fraction of its positions holding a
    rare symbol. Sequences are dropped worst-first (ties keep input order)
    until one of the stop conditions holds:
      * at most ``ceil(len(sequences) * min_coverage)`` sequences remain,
      * fewer than 3 sequences remain,
      * every remaining sequence has the same rarity,
      * ``max_removals`` sequences have been dropped.

    Returns:
        ``(core, outliers)``: ``core`` in input order, ``outliers`` in removal
        order (most rare first).
    """
    import math

    items = list(sequences)
    total = len(items)
    if total == 0:
        return [], []
    counts = _symbol_sequence_counts(items)

    def rarity(seq):
        if not seq:
            return 0.0
        rare = sum(1 for sym in seq if counts[sym] / total < rare_threshold)
        return rare / len(seq)

    # Symbol frequencies are taken over the full dataset and never change,
    # so every sequence is scored exactly once.
    rarities = [rarity(seq) for seq in items]
    lowest = min(rarities)

    # Round up so the core never falls below the requested coverage; the
    # epsilon absorbs float noise such as 10 * 0.7 == 7.000000000000001.
    target = max(math.ceil(total * min_coverage - 1e-9), 1)

    # sorted() is stable, so equally rare sequences keep their input order.
    order = sorted(range(total), key=lambda i: -rarities[i])

    removed = 0
    for idx in order[:max_removals]:
        remaining = total - removed
        if remaining < 3 or remaining <= target:
            break
        if rarities[idx] == lowest:
            # Everything still in the core is equally (un)rare: no outliers left.
            break
        removed += 1

    dropped = set(order[:removed])
    core = [seq for i, seq in enumerate(items) if i not in dropped]
    outliers = [items[i] for i in order[:removed]]
    return core, outliers


def _find_core(sequences, min_coverage=0.8, rare_threshold=0.3, max_removals=50):
    """Find the core subset of sequences by outlier removal, then run CRX on it.

    Outlier detection uses symbol rarity: sequences made up of symbols that
    appear in few sequences are removed first (see ``_split_outliers``).

    Args:
        sequences: Input sequences of hashable symbols.
        min_coverage: Minimum fraction of ``sequences`` to keep in the core.
            Values >= 1.0 disable outlier removal.
        rare_threshold: A symbol present in fewer than this fraction of the
            sequences counts as rare.
        max_removals: Upper bound on the number of outliers removed.

    Returns:
        ``(core_grammar, core_sequences, outliers, fit_scores)``.
        ``fit_scores`` is currently always an empty list. When ``sequences``
        is empty or ``min_coverage >= 1.0`` the input is returned unchanged as
        the core, with no outliers.
    """
    if not sequences or min_coverage >= 1.0:
        return CRX().infer(sequences), sequences, [], []

    core, outliers = _split_outliers(
        sequences, min_coverage, rare_threshold, max_removals
    )
    core_grammar = CRX().infer(core) if core else None
    return core_grammar, core, outliers, []
@@ -276,7 +399,7 @@ _ALGORITHMS = { } -def infer_ensemble(sequences, kmax=2, N=3, prefer=None): +def infer_ensemble(sequences, kmax=2, N=3, prefer=None, min_coverage=1.0): """Run all applicable algorithms and return the best by MDL score. Args: @@ -285,12 +408,18 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None): N: Number of random trials for k-ORE inference. prefer: Optional — 'crx', 'idregex', or 'koreinference' to skip ensemble and return only that algorithm's result. + min_coverage: When < 1.0, also runs CRX on the tightest core subset + of sequences. Outliers (worst-fitting) are iteratively + removed until at least this fraction remains. The core + grammar and outlier list are included in the response. Returns: dict with keys: best: {algorithm, grammar, mdl_score} all: [{algorithm, grammar, mdl_score}, ...] why: str explaining the choice + core: (optional) {grammar, coverage, outliers} — only when + min_coverage < 1.0 """ if prefer and prefer.lower() in _ALGORITHMS: key = prefer.lower() @@ -328,11 +457,19 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None): results = [r for r in results if r[1] and r[1] != '∅'] if not results: - return { + base = { 'best': None, 'all': [], 'why': "No algorithm produced a non-empty grammar.", } + if min_coverage < 1.0: + core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage) + base['core'] = { + 'grammar': core_g, + 'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0, + 'outliers': outliers, + } + return base results.sort(key=lambda x: x[2]) best = results[0] @@ -360,7 +497,7 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None): why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).") - return { + result = { 'best': { 'algorithm': best[0], 'grammar': best[1], @@ -369,3 +506,16 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None): 'all': all_results, 'why': ' '.join(why_parts), } + + # Core analysis when min_coverage < 1.0 + if min_coverage < 1.0: 
# ── min_coverage / core analysis tests ──

def test_core_not_included_when_coverage_1():
    """min_coverage=1.0 disables core analysis, so no 'core' key is reported."""
    seqs = [['a', 'b'], ['a', 'b', 'c']]
    result = infer_ensemble(seqs, min_coverage=1.0)
    assert 'core' not in result


def test_core_included_when_coverage_lt_1():
    """min_coverage < 1.0 adds a 'core' entry carrying all documented keys."""
    seqs = [['a', 'b'], ['a', 'b', 'c']]
    result = infer_ensemble(seqs, min_coverage=0.8)
    assert 'core' in result
    assert 'grammar' in result['core']
    assert 'coverage' in result['core']
    assert 'outliers' in result['core']
    assert 'outlier_count' in result['core']


def test_core_outlier_detection():
    """A sequence carrying a symbol seen in only one sequence is an outlier."""
    seqs = [
        ['fail', 'package', 'file', 'service'],
        ['fail', 'package', 'file', 'service'],
        ['fail', 'package', 'file', 'service', 'npm'],
        ['fail', 'package', 'file', 'service', 'npm', 'pip'],
    ]
    result = infer_ensemble(seqs, min_coverage=0.7)
    assert 'core' in result
    c = result['core']
    assert c['outlier_count'] >= 1
    assert 'npm' in c['grammar'] or 'service' in c['grammar']


def test_core_all_identical():
    """Identical sequences are all equally rare, so nothing is removed."""
    seqs = [['a', 'b', 'c']] * 10
    result = infer_ensemble(seqs, min_coverage=0.8)
    assert 'core' in result
    assert result['core']['outlier_count'] == 0
    assert 'a' in result['core']['grammar']


def test_core_coverage_ratio():
    """Outliers are reported and the core keeps at least min_coverage."""
    seqs = [
        ['a', 'b', 'c'],
        ['a', 'b', 'c'],
        ['a', 'b', 'c', 'd'],
        ['a', 'b', 'c', 'd', 'e'],
    ]
    result = infer_ensemble(seqs, min_coverage=0.7)
    # Assert the key unconditionally: guarding with `if` would let the test
    # pass vacuously if core analysis silently stopped being reported.
    assert 'core' in result
    c = result['core']
    assert c['outlier_count'] >= 1
    assert len(c['outliers']) >= 1
    assert c['outlier_count'] == len(c['outliers'])
    # Only the sequence containing 'e' is rare here, so 3 of 4 remain.
    assert c['coverage'] >= 0.7


def test_core_empty_sequences():
    """Empty input still yields a 'core' entry when min_coverage < 1.0."""
    result = infer_ensemble([], min_coverage=0.8)
    assert 'core' in result
    # NOTE(review): relies on CRX().infer([]) returning a non-None grammar;
    # confirm against the CRX implementation.
    assert result['core']['grammar'] is not None