feat: core+outlier analysis via min_coverage parameter, 6 new tests
This commit is contained in:
parent
edd6d9d4dd
commit
9045769d57
2 changed files with 214 additions and 3 deletions
156
bex/ensemble.py
156
bex/ensemble.py
|
|
@ -234,6 +234,129 @@ def _matches(grammar, sequence):
|
|||
return False
|
||||
|
||||
|
||||
def _fit_score(grammar, seq):
    """Score how much of ``grammar`` is obligatory, for a ``seq`` it matches.

    The grammar string is parsed with ``_parse_parts`` and only the first
    parsed node is inspected.  Its symbols are split into two tallies:

    * obligatory: symbols whose quantifier is ``''`` or ``'+'``;
    * optional: symbols with any other quantifier, plus one per alternative
      of every disjunction.

    Returns:
        0.0 if ``seq`` is empty, the grammar parses to nothing / to an
        ``'empty'`` node, ``seq`` is not fully matched, or any exception is
        raised while scoring; 0.5 if the first node contributes no counted
        symbols; otherwise ``1 - optional / (obligatory + optional)``.

    Note that once the match succeeds the value depends only on the grammar
    structure, not on which optional parts ``seq`` actually used.
    """
    if not seq:
        return 0.0

    def _tally(node):
        """Return (obligatory_count, optional_count) for one parsed node."""
        kind, value, quantifier = node
        if kind == 'symbol':
            return (1, 0) if quantifier in ('', '+') else (0, 1)
        if kind == 'disj':
            # Every alternative of a disjunction is counted as optional.
            return (0, len(value))
        if kind != 'concat':
            return (0, 0)
        required = optional = 0
        for child in value:
            if child[0] == 'empty':
                continue
            child_required, child_optional = _tally(child)
            required += child_required
            optional += child_optional
        return (required, optional)

    try:
        parsed = _parse_parts(grammar.strip())
        if not parsed or parsed[0][0] == 'empty':
            return 0.0

        required, optional = _tally(parsed[0])
        weight = required + optional
        if weight == 0:
            return 0.5

        # The value from _match_tokens is compared against len(seq), so it is
        # presumably the index where matching stopped (None = no match).
        consumed = _match_tokens(parsed, seq)
        if consumed is None or consumed != len(seq):
            return 0.0

        # The larger the optional share of the grammar, the lower the fit.
        return max(0.0, 1.0 - (optional / weight))
    except Exception:
        # Best-effort scoring: any parse/match failure counts as "no fit".
        return 0.0
|
||||
|
||||
|
||||
def _symbol_rarity_score(seq, all_sequences):
|
||||
"""Score a sequence by how rare its symbols are across the dataset.
|
||||
1.0 = all symbols are common, 0.0 = mostly rare symbols.
|
||||
"""
|
||||
from collections import Counter
|
||||
all_syms = Counter()
|
||||
for s in all_sequences:
|
||||
all_syms.update(s)
|
||||
n = len(all_sequences)
|
||||
scores = []
|
||||
for sym in seq:
|
||||
freq = all_syms.get(sym, 0) / n
|
||||
scores.append(min(freq, 1.0))
|
||||
return sum(scores) / len(scores) if scores else 0.0
|
||||
|
||||
|
||||
def _find_core(sequences, min_coverage=0.8):
    """Find the core subset of sequences by removing rare-symbol outliers.

    A symbol is "rare" when its total occurrence count divided by the number
    of input sequences is below 0.3.  Each sequence is scored by the fraction
    of its symbols that are rare, and the highest-scoring sequence (earliest
    one on ties) is removed repeatedly.  Removal stops when:

    * fewer than 3 sequences remain,
    * removing one more would leave less than ``min_coverage`` of the input,
    * all remaining sequences have the same score, or
    * 50 removals have been made.

    CRX is then run on the remaining sequences.

    Args:
        sequences: List of symbol sequences.
        min_coverage: Minimum fraction of ``sequences`` that must stay in the
            core.  Values >= 1.0 disable outlier removal.

    Returns:
        (core_grammar, core_sequences, outliers, fit_scores) where
        ``outliers`` holds the removed sequences in removal order and
        ``fit_scores`` is currently always an empty list.  When ``sequences``
        is empty or ``min_coverage >= 1.0`` the input list itself is returned
        as ``core_sequences``.
    """
    if not sequences or min_coverage >= 1.0:
        return CRX().infer(sequences), sequences, [], []

    from collections import Counter

    symbol_counts = Counter()
    for seq in sequences:
        symbol_counts.update(seq)
    total = len(sequences)

    def _rarity(seq):
        """Fraction of symbols in seq that are rare in the full dataset."""
        rare = sum(1 for sym in seq if symbol_counts.get(sym, 0) / total < 0.3)
        return rare / max(len(seq), 1)

    working = list(sequences)
    outliers = []

    for _ in range(50):
        if len(working) < 3:
            break

        # Stop before a removal would drop coverage below min_coverage.
        # Comparing the ratio directly avoids the floor of int(total * p),
        # which allowed e.g. 4 sequences at 0.7 to shrink to 2 (50%).
        if (len(working) - 1) / total < min_coverage:
            break

        scores = [_rarity(seq) for seq in working]
        worst = max(scores)

        # If all sequences have the same score there is no outlier to remove.
        if worst == min(scores):
            break

        # list.index returns the earliest sequence with the worst score.
        outliers.append(working.pop(scores.index(worst)))

    core_g = CRX().infer(working) if working else None
    return core_g, working, outliers, []
|
||||
|
||||
|
||||
def mdl_score_simple(grammar, sequences):
|
||||
"""MDL score from the paper: model_cost + Σ log₂(|L(r)| at length len(s)).
|
||||
|
||||
|
|
@ -276,7 +399,7 @@ _ALGORITHMS = {
|
|||
}
|
||||
|
||||
|
||||
def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
||||
def infer_ensemble(sequences, kmax=2, N=3, prefer=None, min_coverage=1.0):
|
||||
"""Run all applicable algorithms and return the best by MDL score.
|
||||
|
||||
Args:
|
||||
|
|
@ -285,12 +408,18 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
|||
N: Number of random trials for k-ORE inference.
|
||||
prefer: Optional — 'crx', 'idregex', or 'koreinference' to skip
|
||||
ensemble and return only that algorithm's result.
|
||||
min_coverage: When < 1.0, also runs CRX on the tightest core subset
|
||||
of sequences. Outliers (worst-fitting) are iteratively
|
||||
removed until at least this fraction remains. The core
|
||||
grammar and outlier list are included in the response.
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
best: {algorithm, grammar, mdl_score}
|
||||
all: [{algorithm, grammar, mdl_score}, ...]
|
||||
why: str explaining the choice
|
||||
core: (optional) {grammar, coverage, outliers} — only when
|
||||
min_coverage < 1.0
|
||||
"""
|
||||
if prefer and prefer.lower() in _ALGORITHMS:
|
||||
key = prefer.lower()
|
||||
|
|
@ -328,11 +457,19 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
|||
|
||||
results = [r for r in results if r[1] and r[1] != '∅']
|
||||
if not results:
|
||||
return {
|
||||
base = {
|
||||
'best': None,
|
||||
'all': [],
|
||||
'why': "No algorithm produced a non-empty grammar.",
|
||||
}
|
||||
if min_coverage < 1.0:
|
||||
core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage)
|
||||
base['core'] = {
|
||||
'grammar': core_g,
|
||||
'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0,
|
||||
'outliers': outliers,
|
||||
}
|
||||
return base
|
||||
|
||||
results.sort(key=lambda x: x[2])
|
||||
best = results[0]
|
||||
|
|
@ -360,7 +497,7 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
|||
|
||||
why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).")
|
||||
|
||||
return {
|
||||
result = {
|
||||
'best': {
|
||||
'algorithm': best[0],
|
||||
'grammar': best[1],
|
||||
|
|
@ -369,3 +506,16 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
|||
'all': all_results,
|
||||
'why': ' '.join(why_parts),
|
||||
}
|
||||
|
||||
# Core analysis when min_coverage < 1.0
|
||||
if min_coverage < 1.0:
|
||||
core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage)
|
||||
result['core'] = {
|
||||
'grammar': core_g,
|
||||
'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0,
|
||||
'outlier_count': len(outliers),
|
||||
'outliers': outliers,
|
||||
}
|
||||
result['why'] += f' Core CRX ({min_coverage:.0%} coverage, {len(outliers)} outliers): {core_g}'
|
||||
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -164,6 +164,67 @@ def test_ensemble_crx_always_present():
|
|||
assert len(crx_results) == 1
|
||||
|
||||
|
||||
# ── min_coverage / core analysis tests ──
|
||||
|
||||
def test_core_not_included_when_coverage_1():
    """min_coverage=1.0 must not add a 'core' entry to the result."""
    sequences = [['a', 'b'], ['a', 'b', 'c']]
    outcome = infer_ensemble(sequences, min_coverage=1.0)
    assert 'core' not in outcome
|
||||
|
||||
|
||||
def test_core_included_when_coverage_lt_1():
    """min_coverage < 1.0 adds a 'core' entry carrying all expected keys."""
    sequences = [['a', 'b'], ['a', 'b', 'c']]
    outcome = infer_ensemble(sequences, min_coverage=0.8)
    assert 'core' in outcome
    core = outcome['core']
    for key in ('grammar', 'coverage', 'outliers', 'outlier_count'):
        assert key in core
|
||||
|
||||
|
||||
def test_core_outlier_detection():
    """A sequence carrying a rarely seen symbol is reported as an outlier."""
    common = ['fail', 'package', 'file', 'service']
    sequences = [
        list(common),
        list(common),
        common + ['npm'],
        common + ['npm', 'pip'],
    ]
    outcome = infer_ensemble(sequences, min_coverage=0.7)
    assert 'core' in outcome
    core = outcome['core']
    assert core['outlier_count'] >= 1
    grammar = core['grammar']
    assert 'npm' in grammar or 'service' in grammar
|
||||
|
||||
|
||||
def test_core_all_identical():
    """Identical sequences yield no outliers and a grammar mentioning 'a'."""
    sequences = [['a', 'b', 'c']] * 10
    outcome = infer_ensemble(sequences, min_coverage=0.8)
    assert 'core' in outcome
    core = outcome['core']
    assert core['outlier_count'] == 0
    assert 'a' in core['grammar']
|
||||
|
||||
|
||||
def test_core_coverage_ratio():
    """Outlier removal reports at least one outlier and keeps coverage >= 0.5.

    'core' is asserted explicitly: with min_coverage < 1.0 it must be
    present, and guarding the checks with ``if 'core' in result`` would let
    the test pass without checking anything.
    """
    seqs = [
        ['a', 'b', 'c'],
        ['a', 'b', 'c'],
        ['a', 'b', 'c', 'd'],
        ['a', 'b', 'c', 'd', 'e'],
    ]
    result = infer_ensemble(seqs, min_coverage=0.7)
    assert 'core' in result
    c = result['core']
    assert c['outlier_count'] >= 1
    assert len(c['outliers']) >= 1
    assert c['coverage'] >= 0.5
|
||||
|
||||
|
||||
def test_core_empty_sequences():
    """Empty input still produces a 'core' entry with a non-None grammar."""
    outcome = infer_ensemble([], min_coverage=0.8)
    assert 'core' in outcome
    core = outcome['core']
    assert core['grammar'] is not None
|
||||
|
||||
|
||||
def run_all():
|
||||
tests = [
|
||||
test_ensemble_returns_dict,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue