feat: core+outlier analysis via min_coverage parameter, 6 new tests
This commit is contained in:
parent
edd6d9d4dd
commit
9045769d57
2 changed files with 214 additions and 3 deletions
156
bex/ensemble.py
156
bex/ensemble.py
|
|
@ -234,6 +234,129 @@ def _matches(grammar, sequence):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _fit_score(grammar, seq):
    """Score a sequence against a grammar by the grammar's obligatory share.

    The grammar string is parsed with ``_parse_parts`` and the first parsed
    node is walked to count obligatory versus optional symbols. The sequence
    itself only gates the result: it must be fully consumed by
    ``_match_tokens``; otherwise the score is 0.0.

    Args:
        grammar: Grammar string; it is stripped before parsing.
        seq: Sequence of symbols to test against the grammar.

    Returns:
        0.0 for an empty ``seq``, an empty/unparsable grammar, a partial or
        failed match, or any exception raised along the way; 0.5 when the
        first node yields no countable symbols; otherwise
        ``1.0 - optional / (obligatory + optional)``.
    """
    if not seq:
        return 0.0

    def tally(node):
        """Return (obligatory, optional) symbol counts for one parsed node."""
        # Nodes are unpacked as (type, value, quantifier) triples.
        kind, payload, quant = node
        if kind == 'symbol':
            # Bare and '+' quantified symbols must appear at least once.
            return (1, 0) if quant in ('', '+') else (0, 1)
        if kind == 'disj':
            # Every alternative of a disjunction is counted as optional.
            return (0, len(payload))
        if kind != 'concat':
            return (0, 0)
        required = optional = 0
        for child in payload:
            if child[0] == 'empty':
                continue
            req, opt = tally(child)
            required += req
            optional += opt
        return (required, optional)

    # Deliberately best-effort: any failure while parsing, counting or
    # matching is reported as "no fit" rather than propagated.
    try:
        parts = _parse_parts(grammar.strip())
        if not parts or parts[0][0] == 'empty':
            return 0.0

        required, optional = tally(parts[0])
        weight = required + optional
        if weight == 0:
            return 0.5

        # The sequence has to be consumed in full to count as a match.
        consumed = _match_tokens(parts, seq)
        if consumed is None or consumed != len(seq):
            return 0.0

        # Grammars that lean on optional parts are penalised.
        return max(0.0, 1.0 - (optional / weight))
    except Exception:
        return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _symbol_rarity_score(seq, all_sequences):
|
||||||
|
"""Score a sequence by how rare its symbols are across the dataset.
|
||||||
|
1.0 = all symbols are common, 0.0 = mostly rare symbols.
|
||||||
|
"""
|
||||||
|
from collections import Counter
|
||||||
|
all_syms = Counter()
|
||||||
|
for s in all_sequences:
|
||||||
|
all_syms.update(s)
|
||||||
|
n = len(all_sequences)
|
||||||
|
scores = []
|
||||||
|
for sym in seq:
|
||||||
|
freq = all_syms.get(sym, 0) / n
|
||||||
|
scores.append(min(freq, 1.0))
|
||||||
|
return sum(scores) / len(scores) if scores else 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _find_core(sequences, min_coverage=0.8):
    """Find the core subset of sequences by removing rare-symbol outliers.

    A symbol counts as rare when its total occurrences divided by the number
    of input sequences is below 0.3. Each sequence's rarity is the fraction
    of its symbols that are rare. The sequence with the highest rarity is
    removed repeatedly (ties go to the earliest one) until one of these holds:

    * only ``ceil(len(sequences) * min_coverage)`` sequences remain, so the
      kept fraction never drops below ``min_coverage``;
    * fewer than 3 sequences remain;
    * all remaining sequences have the same rarity (nothing stands out);
    * 50 sequences have been removed (hard safety cap).

    CRX is then run on what is left.

    Args:
        sequences: List of symbol sequences.
        min_coverage: Minimum fraction of ``sequences`` to keep. Values
            ``>= 1.0`` disable outlier removal.

    Returns:
        ``(core_grammar, core_sequences, outliers, fit_scores)`` where
        ``outliers`` holds the removed sequences themselves, in removal
        order, and ``fit_scores`` is currently always an empty list. When
        ``sequences`` is empty or ``min_coverage >= 1.0`` the input object is
        returned unchanged as ``core_sequences``.
    """
    import math
    from collections import Counter

    if not sequences or min_coverage >= 1.0:
        crx_g = CRX().infer(sequences)
        return crx_g, sequences, [], []

    rare_threshold = 0.3
    max_removals = 50

    all_syms = Counter()
    for s in sequences:
        all_syms.update(s)
    n = len(sequences)

    def _rarity(seq):
        rare_count = sum(
            1 for sym in seq if all_syms.get(sym, 0) / n < rare_threshold
        )
        return rare_count / max(len(seq), 1)

    # Round up so the kept fraction honours min_coverage. The epsilon stops
    # float noise (e.g. 10 * 0.7 == 7.000000000000001) from adding one.
    target = max(math.ceil(n * min_coverage - 1e-9), 1)

    working = list(sequences)
    # Symbol frequencies are fixed, so each rarity is computed only once.
    rarity = [_rarity(seq) for seq in working]
    outliers = []

    for _ in range(max_removals):
        if len(working) < 3 or len(working) <= target:
            break

        worst = max(rarity)
        # Identical scores everywhere means there is no outlier to remove.
        if worst == min(rarity):
            break

        worst_idx = rarity.index(worst)  # earliest among ties
        rarity.pop(worst_idx)
        outliers.append(working.pop(worst_idx))

    core_g = CRX().infer(working) if working else None
    return core_g, working, outliers, []
|
||||||
|
|
||||||
|
|
||||||
def mdl_score_simple(grammar, sequences):
|
def mdl_score_simple(grammar, sequences):
|
||||||
"""MDL score from the paper: model_cost + Σ log₂(|L(r)| at length len(s)).
|
"""MDL score from the paper: model_cost + Σ log₂(|L(r)| at length len(s)).
|
||||||
|
|
||||||
|
|
@ -276,7 +399,7 @@ _ALGORITHMS = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
def infer_ensemble(sequences, kmax=2, N=3, prefer=None, min_coverage=1.0):
|
||||||
"""Run all applicable algorithms and return the best by MDL score.
|
"""Run all applicable algorithms and return the best by MDL score.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -285,12 +408,18 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
||||||
N: Number of random trials for k-ORE inference.
|
N: Number of random trials for k-ORE inference.
|
||||||
prefer: Optional — 'crx', 'idregex', or 'koreinference' to skip
|
prefer: Optional — 'crx', 'idregex', or 'koreinference' to skip
|
||||||
ensemble and return only that algorithm's result.
|
ensemble and return only that algorithm's result.
|
||||||
|
min_coverage: When < 1.0, also runs CRX on the tightest core subset
|
||||||
|
of sequences. Outliers (worst-fitting) are iteratively
|
||||||
|
removed until at least this fraction remains. The core
|
||||||
|
grammar and outlier list are included in the response.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict with keys:
|
dict with keys:
|
||||||
best: {algorithm, grammar, mdl_score}
|
best: {algorithm, grammar, mdl_score}
|
||||||
all: [{algorithm, grammar, mdl_score}, ...]
|
all: [{algorithm, grammar, mdl_score}, ...]
|
||||||
why: str explaining the choice
|
why: str explaining the choice
|
||||||
|
core: (optional) {grammar, coverage, outliers} — only when
|
||||||
|
min_coverage < 1.0
|
||||||
"""
|
"""
|
||||||
if prefer and prefer.lower() in _ALGORITHMS:
|
if prefer and prefer.lower() in _ALGORITHMS:
|
||||||
key = prefer.lower()
|
key = prefer.lower()
|
||||||
|
|
@ -328,11 +457,19 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
||||||
|
|
||||||
results = [r for r in results if r[1] and r[1] != '∅']
|
results = [r for r in results if r[1] and r[1] != '∅']
|
||||||
if not results:
|
if not results:
|
||||||
return {
|
base = {
|
||||||
'best': None,
|
'best': None,
|
||||||
'all': [],
|
'all': [],
|
||||||
'why': "No algorithm produced a non-empty grammar.",
|
'why': "No algorithm produced a non-empty grammar.",
|
||||||
}
|
}
|
||||||
|
if min_coverage < 1.0:
|
||||||
|
core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage)
|
||||||
|
base['core'] = {
|
||||||
|
'grammar': core_g,
|
||||||
|
'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0,
|
||||||
|
'outliers': outliers,
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
|
||||||
results.sort(key=lambda x: x[2])
|
results.sort(key=lambda x: x[2])
|
||||||
best = results[0]
|
best = results[0]
|
||||||
|
|
@ -360,7 +497,7 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
||||||
|
|
||||||
why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).")
|
why_parts.append(f"{best[0]} selected (MDL score {best[2]:.1f}).")
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
'best': {
|
'best': {
|
||||||
'algorithm': best[0],
|
'algorithm': best[0],
|
||||||
'grammar': best[1],
|
'grammar': best[1],
|
||||||
|
|
@ -369,3 +506,16 @@ def infer_ensemble(sequences, kmax=2, N=3, prefer=None):
|
||||||
'all': all_results,
|
'all': all_results,
|
||||||
'why': ' '.join(why_parts),
|
'why': ' '.join(why_parts),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Core analysis when min_coverage < 1.0
|
||||||
|
if min_coverage < 1.0:
|
||||||
|
core_g, core_seqs, outliers, _ = _find_core(sequences, min_coverage)
|
||||||
|
result['core'] = {
|
||||||
|
'grammar': core_g,
|
||||||
|
'coverage': round(len(core_seqs) / max(len(sequences), 1), 2) if sequences else 0,
|
||||||
|
'outlier_count': len(outliers),
|
||||||
|
'outliers': outliers,
|
||||||
|
}
|
||||||
|
result['why'] += f' Core CRX ({min_coverage:.0%} coverage, {len(outliers)} outliers): {core_g}'
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
|
||||||
|
|
@ -164,6 +164,67 @@ def test_ensemble_crx_always_present():
|
||||||
assert len(crx_results) == 1
|
assert len(crx_results) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── min_coverage / core analysis tests ──
|
||||||
|
|
||||||
|
def test_core_not_included_when_coverage_1():
    """With min_coverage=1.0 the response carries no 'core' entry."""
    sequences = [['a', 'b'], ['a', 'b', 'c']]
    outcome = infer_ensemble(sequences, min_coverage=1.0)
    assert 'core' not in outcome
|
||||||
|
|
||||||
|
|
||||||
|
def test_core_included_when_coverage_lt_1():
    """With min_coverage < 1.0 the response has a fully populated 'core'."""
    sequences = [['a', 'b'], ['a', 'b', 'c']]
    outcome = infer_ensemble(sequences, min_coverage=0.8)
    assert 'core' in outcome
    core = outcome['core']
    for expected_key in ('grammar', 'coverage', 'outliers', 'outlier_count'):
        assert expected_key in core
|
||||||
|
|
||||||
|
|
||||||
|
def test_core_outlier_detection():
    """Sequences with rare trailing symbols are reported as outliers."""
    common = ['fail', 'package', 'file', 'service']
    seqs = [
        list(common),
        list(common),
        common + ['npm'],
        common + ['npm', 'pip'],
    ]
    outcome = infer_ensemble(seqs, min_coverage=0.7)
    assert 'core' in outcome
    core = outcome['core']
    assert core['outlier_count'] >= 1
    core_grammar = core['grammar']
    assert any(symbol in core_grammar for symbol in ('npm', 'service'))
|
||||||
|
|
||||||
|
|
||||||
|
def test_core_all_identical():
    """Identical sequences produce no outliers and keep their symbols."""
    repeated = [['a', 'b', 'c']] * 10
    outcome = infer_ensemble(repeated, min_coverage=0.8)
    assert 'core' in outcome
    core = outcome['core']
    assert core['outlier_count'] == 0
    assert 'a' in core['grammar']
|
||||||
|
|
||||||
|
|
||||||
|
def test_core_coverage_ratio():
    """Outlier removal reports its outliers and keeps at least half the data."""
    seqs = [
        ['a', 'b', 'c'],
        ['a', 'b', 'c'],
        ['a', 'b', 'c', 'd'],
        ['a', 'b', 'c', 'd', 'e'],
    ]
    result = infer_ensemble(seqs, min_coverage=0.7)
    # Assert the key instead of guarding on it: a missing 'core' entry used
    # to skip every check below and let the test pass vacuously.
    assert 'core' in result
    c = result['core']
    assert c['outlier_count'] >= 1
    assert len(c['outliers']) >= 1
    assert c['coverage'] >= 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_core_empty_sequences():
    """An empty input still yields a 'core' entry with a non-None grammar."""
    outcome = infer_ensemble([], min_coverage=0.8)
    assert 'core' in outcome
    core_grammar = outcome['core']['grammar']
    assert core_grammar is not None
|
||||||
|
|
||||||
|
|
||||||
def run_all():
|
def run_all():
|
||||||
tests = [
|
tests = [
|
||||||
test_ensemble_returns_dict,
|
test_ensemble_returns_dict,
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue