Segmentation Forests
Unsupervised segment discovery using divergence-based decision trees inspired by Random Forests.
Automatically discover meaningful segments in your data where metric distributions significantly diverge from the background. Perfect for exploratory data analysis, anomaly detection, and customer segmentation without requiring labeled data.
🎯 The Problem
You have a dataset with many categorical features and a metric you care about. For example:
- E-commerce: Users across countries, devices, age groups → conversion rate
- Digital advertising: Impressions across demographics, platforms, times → CTR
- Healthcare: Patients across conditions, treatments, demographics → readmission rate
- Finance: Transactions across customer segments, times → fraud rate
The question: Which specific combinations of features exhibit unusual behavior?
With N features and many values per feature, the number of feature-value combinations grows exponentially, so exhaustively testing every one quickly becomes infeasible. Segmentation Forests addresses this with a greedy search for segments where your metric's distribution differs significantly from the overall population.
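To make the combinatorial growth concrete, here is a small sketch that counts how many conjunctive segments exist for a given set of feature cardinalities. The function name and the example cardinalities are hypothetical and are not part of the library.

```python
from math import prod


def count_candidate_segments(cardinalities):
    """Count non-empty conjunctive segments over categorical features.

    Each feature is either left unconstrained or pinned to one of its k
    values, giving (k + 1) options per feature. Subtract 1 for the case
    where every feature is unconstrained (the whole population).
    """
    return prod(k + 1 for k in cardinalities) - 1


# Hypothetical cardinalities: country=5, device=3, gender=2, time_of_day=4
small = count_candidate_segments([5, 3, 2, 4])
print(small)  # 359

# Twenty features with ten values each
large = count_candidate_segments([10] * 20)
print(f"{large:.2e}")
```

Even four small features yield hundreds of candidate segments, and the count multiplies with every feature added, which is why a greedy search is used instead of enumeration.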
✨ Key Features
- 🌳 Tree-based discovery: Greedy algorithm efficiently navigates combinatorial feature space
- 🌲 Forest ensemble: Bootstrap aggregating for robust, reproducible patterns
- 📊 Statistical rigor: KS distance (continuous) & Jensen-Shannon divergence (discrete)
- 📈 Beautiful visualizations: Distribution comparisons and quality assessments
- 🔬 Fully typed: Complete type hints for excellent IDE support
- ⚡ Fast & scalable: Handles datasets with millions of rows
🚀 Quick Start
Installation
pip install segmentation-forests
Basic Example
import pandas as pd
from segmentation_forests import SegmentationTree, SegmentationForest
# Your data: features + metric
data = pd.DataFrame({
    'country': ['US', 'UK', 'US', 'UK', ...],
    'device': ['Mobile', 'Desktop', 'Mobile', ...],
    'gender': ['F', 'M', 'F', ...],
    'impressions': [245, 103, 312, 98, ...]  # Your metric
})
# Discover segments with a single tree
tree = SegmentationTree(max_depth=3, min_samples_split=100)
tree.fit(data, metric_column='impressions')
segments = tree.get_segments(min_divergence=0.1)
# View results
for i, seg in enumerate(segments[:3], 1):
    print(f"{i}. {seg.get_condition_string()}")
    print(f" Divergence: {seg.divergence:.3f} | Size: {seg.size:,}")
Output:
1. gender == F AND device == Mobile AND country == UK
Divergence: 0.948 | Size: 523
2. time_of_day == Evening AND country == US AND device == Desktop
Divergence: 0.856 | Size: 412
3. country == DE AND time_of_day == Morning
Divergence: 0.824 | Size: 289
🌲 Using the Forest (Recommended)
For more robust results, use the ensemble approach:
from segmentation_forests import SegmentationForest
# Create forest with bootstrap sampling and random features
forest = SegmentationForest(
    n_trees=10,
    max_depth=3,
    max_features=2,  # Random feature selection
    min_samples_split=100,
    min_samples_leaf=50
)
# Fit and get segments found by multiple trees
forest.fit(data, metric_column='impressions')
robust_segments = forest.get_segments(min_support=3, min_divergence=0.1)
# View results
for seg in robust_segments:
    cond_str = " AND ".join([f"{c[0]} {c[1]} {c[2]}" for c in seg['conditions']])
    print(f"{cond_str}")
    print(f" Support: {seg['support']}/10 trees ({seg['support_rate']*100:.0f}%)")
    print(f" Avg Divergence: {seg['avg_divergence']:.3f}")
    print()
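The two sources of randomness mentioned above (bootstrap sampling and random feature selection) can be illustrated with a short standalone sketch. This is not the library's implementation; the helper names are hypothetical, and it only assumes the usual bagging recipe of sampling rows with replacement and restricting the features considered at a split.

```python
import random


def bootstrap_indices(n_rows, rng):
    """Sample n_rows row indices with replacement (one bootstrap replicate)."""
    return [rng.randrange(n_rows) for _ in range(n_rows)]


def random_feature_subset(features, max_features, rng):
    """Pick max_features distinct features; None means use all of them."""
    if max_features is None or max_features >= len(features):
        return list(features)
    return rng.sample(list(features), max_features)


rng = random.Random(0)  # fixed seed so the sketch is repeatable
features = ["country", "device", "gender", "time_of_day"]

for tree_id in range(3):
    rows = bootstrap_indices(1000, rng)
    subset = random_feature_subset(features, 2, rng)
    # Each replicate leaves out some rows, so the count of distinct rows is below 1000
    print(tree_id, len(set(rows)), subset)
```

Because every tree sees a slightly different sample and feature subset, a segment that keeps reappearing across trees is less likely to be an artifact of one particular draw.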
📊 Visualization
Beautiful distribution comparison plots:
from segmentation_forests.visualization import plot_segment_comparison
# Compare segment distribution vs background
fig = plot_segment_comparison(
    data=data,
    segment_conditions=[('country', '==', 'UK'), ('device', '==', 'Mobile')],
    metric_column='impressions',
    title='UK Mobile Users vs Background'
)
fig.savefig('segment_comparison.png', dpi=150)
The plot shows:
- Left: Overlapping histograms (background in blue, segment in coral)
- Right: Box plots comparing distributions
- Clear separation: Strong segments show minimal overlap
🧠 How It Works
Algorithm Overview
- Compute Background Distribution: Calculate the distribution of your metric across all data
- Greedy Tree Building:
  - At each node, evaluate all feature-value splits
  - Choose the split that maximizes divergence from background
  - Recursively build left (matching condition) and right (not matching) subtrees
- Collect High-Divergence Leaves: Return segments that diverge significantly
- Ensemble Aggregation (Forest only): Vote across trees to find robust patterns
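A single greedy step of the tree-building stage can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's code: rows are plain dicts, splits have the form `feature == value`, divergence is the KS distance between the matching rows and the background, and the names `ks_distance` and `best_split` are hypothetical.

```python
from bisect import bisect_right


def ks_distance(sample, background):
    """Maximum absolute gap between the two empirical CDFs."""
    a, b = sorted(sample), sorted(background)
    return max(
        abs(bisect_right(a, x) / len(a) - bisect_right(b, x) / len(b))
        for x in set(a) | set(b)
    )


def best_split(rows, metric, features, background, min_samples_leaf=1):
    """Return (divergence, feature, value) for the best `feature == value` split."""
    best = None
    for feature in features:
        for value in sorted({r[feature] for r in rows}):
            matched = [r[metric] for r in rows if r[feature] == value]
            rest = len(rows) - len(matched)
            # Skip splits that would leave either child too small
            if len(matched) < min_samples_leaf or rest < min_samples_leaf:
                continue
            score = ks_distance(matched, background)
            if best is None or score > best[0]:
                best = (score, feature, value)
    return best


rows = [
    {"country": "UK", "device": "Mobile", "impressions": 300},
    {"country": "US", "device": "Mobile", "impressions": 310},
    {"country": "UK", "device": "Desktop", "impressions": 100},
    {"country": "US", "device": "Desktop", "impressions": 95},
    {"country": "UK", "device": "Desktop", "impressions": 98},
    {"country": "US", "device": "Desktop", "impressions": 105},
]
background = [r["impressions"] for r in rows]
best = best_split(rows, "impressions", ["country", "device"], background)
print(best)  # device == Mobile wins with a KS distance of about 0.667
```

A full tree would apply the same step recursively to the matching and non-matching rows until `max_depth` or the minimum sample sizes are reached.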
Divergence Measures
The algorithm automatically selects the appropriate measure:
| Metric Type | Measure | Range | Description |
|---|---|---|---|
| Continuous | Kolmogorov-Smirnov | [0, 1] | Max distance between CDFs |
| Discrete | Jensen-Shannon | [0, 1] | Symmetrized, bounded variant of KL divergence |
Decision threshold: ≤20 unique values → discrete, >20 → continuous
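For the discrete case, the following sketch shows one way to compute Jensen-Shannon divergence from raw samples, together with the unique-value rule for picking a measure. It assumes base-2 logarithms (which is what bounds the divergence by 1) and returns the divergence itself rather than its square root; whether the library makes the same choices is not stated here. The function names are hypothetical.

```python
from collections import Counter
from math import log2


def js_divergence(sample, background):
    """Jensen-Shannon divergence with base-2 logs, so the result lies in [0, 1]."""
    p, q = Counter(sample), Counter(background)
    n_p, n_q = len(sample), len(background)
    total = 0.0
    for value in set(p) | set(q):
        pi, qi = p[value] / n_p, q[value] / n_q
        mi = 0.5 * (pi + qi)  # mixture distribution M = (P + Q) / 2
        if pi > 0:
            total += 0.5 * pi * log2(pi / mi)  # half of KL(P || M)
        if qi > 0:
            total += 0.5 * qi * log2(qi / mi)  # half of KL(Q || M)
    return total


def choose_measure(metric_values, max_discrete_values=20):
    """Apply the threshold above: few unique values means a discrete metric."""
    if len(set(metric_values)) <= max_discrete_values:
        return "jensen_shannon"
    return "kolmogorov_smirnov"


print(js_divergence([1, 1, 2], [1, 1, 2]))  # identical distributions give 0.0
print(js_divergence([0, 0], [1, 1]))        # disjoint supports give 1.0
print(choose_measure([0, 1, 1, 0, 1]))      # jensen_shannon
print(choose_measure(range(1000)))          # kolmogorov_smirnov
```

Unlike plain KL divergence, this quantity stays finite when one distribution has values the other lacks, which matters for small segments.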
Quality Guidelines
Interpret divergence scores:
- ≥ 0.5: 🎯 Excellent - Strong, highly actionable pattern
- 0.3-0.5: ✓ Good - Meaningful difference worth investigating
- 0.1-0.3: ⚠️ Weak - Marginal effect, could be noise
- < 0.1: ❌ Very weak - Likely statistical noise
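These bands are easy to encode as a helper for labelling results. The function below is a hypothetical convenience, not part of the package, and it assumes the lower bound of each band is inclusive since the guidelines do not say which side the boundaries fall on.

```python
def divergence_quality(score):
    """Map a divergence score to the quality bands listed above.

    Assumption: lower bounds are inclusive (0.3 counts as "good").
    """
    if score >= 0.5:
        return "excellent"
    if score >= 0.3:
        return "good"
    if score >= 0.1:
        return "weak"
    return "very weak"


for score in (0.948, 0.35, 0.12, 0.05):
    print(f"{score:.3f} -> {divergence_quality(score)}")
```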
📖 API Reference
SegmentationTree
SegmentationTree(
    max_depth: int = 5,
    min_samples_split: int = 50,
    min_samples_leaf: int = 20,
    divergence_threshold: float = 0.01,
    random_features: Optional[int] = None
)
Parameters:
- max_depth: Maximum tree depth (controls segment complexity)
- min_samples_split: Minimum samples required to split a node
- min_samples_leaf: Minimum samples required in each child
- divergence_threshold: Minimum divergence to keep a segment
- random_features: Number of random features per split (None = use all)
Methods:
- fit(data: pd.DataFrame, metric_column: str) -> Self: Fit tree to data
- get_segments(min_divergence: float = 0.0) -> List[SegmentationNode]: Get segments
SegmentationForest
SegmentationForest(
    n_trees: int = 10,
    max_depth: int = 5,
    min_samples_split: int = 50,
    min_samples_leaf: int = 20,
    divergence_threshold: float = 0.01,
    max_features: Optional[int] = None
)
Parameters:
- n_trees: Number of trees in the forest
- Other parameters: same as SegmentationTree
Methods:
- fit(data: pd.DataFrame, metric_column: str) -> Self: Fit forest
- get_segments(min_support: int = 2, min_divergence: float = 0.0) -> List[Dict]: Get robust segments
Returns: List of dicts with keys:
- conditions: List of (column, operator, value) tuples
- support: Number of trees that found this segment
- avg_divergence: Average divergence across trees
- avg_size: Average segment size
- support_rate: Fraction of trees (support / n_trees)
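To show how per-tree results could be turned into the dictionary format above, here is a minimal aggregation sketch. It is not the library's implementation: it assumes each tree reports segments as (conditions, divergence, size) triples, treats two segments as identical when they contain the same set of conditions regardless of order, and applies `min_divergence` to the averaged value. The function name is hypothetical.

```python
from collections import defaultdict


def aggregate_segments(per_tree_segments, n_trees, min_support=2, min_divergence=0.0):
    """Group identical segments across trees and compute support statistics."""
    groups = defaultdict(list)
    for tree_segments in per_tree_segments:
        seen = set()
        for conditions, divergence, size in tree_segments:
            key = frozenset(conditions)  # order of conditions does not matter
            if key in seen:
                continue  # count each tree at most once per segment
            seen.add(key)
            groups[key].append((divergence, size))

    results = []
    for key, hits in groups.items():
        support = len(hits)
        avg_divergence = sum(d for d, _ in hits) / support
        if support < min_support or avg_divergence < min_divergence:
            continue
        results.append({
            "conditions": sorted(key),
            "support": support,
            "support_rate": support / n_trees,
            "avg_divergence": avg_divergence,
            "avg_size": sum(s for _, s in hits) / support,
        })
    # Most widely supported, most divergent segments first
    return sorted(results, key=lambda r: (-r["support"], -r["avg_divergence"]))


uk = ("country", "==", "UK")
mobile = ("device", "==", "Mobile")
per_tree = [
    [([uk, mobile], 0.9, 500)],
    [([mobile, uk], 0.8, 520)],            # same segment, different order
    [([("gender", "==", "F")], 0.2, 4000)],
]
robust = aggregate_segments(per_tree, n_trees=3, min_support=2)
print(robust)
```

Only the UK mobile segment survives here, because the gender-only segment was found by a single tree and falls below `min_support=2`.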
SegmentationNode
Represents a discovered segment.
Attributes:
- conditions: List of (column, operator, value) tuples
- divergence: Divergence score
- size: Number of data points
- depth: Depth in tree
- data_indices: Indices of data points in this segment
Methods:
get_condition_string() -> str: Human-readable condition string
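The (column, operator, value) tuples are simple enough to work with directly. The sketch below renders them in the same style as the example output and applies them to rows stored as plain dicts. It is an illustration only: the helper names are hypothetical, and support for a `!=` operator (for the "not matching" branch) is an assumption.

```python
import operator

# Assumed operator set: "==" for matching branches, "!=" for non-matching ones
OPS = {"==": operator.eq, "!=": operator.ne}


def condition_string(conditions):
    """Render conditions as 'col op value AND col op value ...'."""
    return " AND ".join(f"{col} {op} {val}" for col, op, val in conditions)


def matches(row, conditions):
    """True when the row satisfies every condition in the segment."""
    return all(OPS[op](row[col], val) for col, op, val in conditions)


conditions = [("gender", "==", "F"), ("device", "==", "Mobile"), ("country", "!=", "US")]
rows = [
    {"gender": "F", "device": "Mobile", "country": "UK"},
    {"gender": "F", "device": "Mobile", "country": "US"},
    {"gender": "M", "device": "Mobile", "country": "UK"},
]

print(condition_string(conditions))
# gender == F AND device == Mobile AND country != US
selected = [r for r in rows if matches(r, conditions)]
print(len(selected))  # 1
```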
🎨 Visualization Functions
plot_segment_comparison
plot_segment_comparison(
    data: pd.DataFrame,
    segment_conditions: List[Tuple],
    metric_column: str,
    title: Optional[str] = None,
    figsize: Tuple = (14, 5)
) -> plt.Figure
Creates side-by-side histogram and box plot comparison.
💡 Usage Tips
Choosing Parameters
For max_depth:
- depth=2: Simple 2-condition segments (e.g., "Country=UK AND Device=Mobile")
- depth=3-4: Recommended - Balanced complexity
- depth=5+: Complex segments, risk of overfitting
For min_divergence:
- Start with 0.1 to see all interesting patterns
- Increase to 0.3+ to focus only on strong effects
- Use forest min_support to filter noise instead
For forest:
- n_trees=10: Good default
- n_trees=20+: More robust but slower
- max_features=sqrt(n_features): Good for high-dimensional data
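Since `max_features` is typed as an integer, the square-root rule needs rounding before it is passed in. A small hypothetical helper, assuming rounding down with a floor of 1 is acceptable:

```python
import math


def sqrt_max_features(n_features):
    """Integer square root of the feature count, never below 1."""
    return max(1, math.isqrt(n_features))


for n in (1, 4, 20, 100):
    print(n, sqrt_max_features(n))
# 1 1
# 4 2
# 20 4
# 100 10
```

With the four features used in the examples here this gives 2, which matches the `max_features=2` setting shown in the Quick Start.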
Interpreting Results
- Always visualize top segments to verify they make sense
- Check segment size - very small segments may be spurious
- Use forest support - patterns in 5+/10 trees are highly reliable
- Domain validation - do discovered segments align with business intuition?
🔬 Example: Advertising Dataset
from segmentation_forests import SegmentationForest
from segmentation_forests.visualization import plot_segment_comparison
import pandas as pd
import numpy as np
# Create synthetic advertising data
np.random.seed(42)
n = 10000
data = pd.DataFrame({
    'country': np.random.choice(['US', 'UK', 'CA', 'DE', 'FR'], n),
    'device': np.random.choice(['Mobile', 'Desktop', 'Tablet'], n),
    'gender': np.random.choice(['M', 'F'], n),
    'time_of_day': np.random.choice(['Morning', 'Afternoon', 'Evening', 'Night'], n),
    'impressions': np.random.poisson(100, n)  # Base: ~100 impressions
})
# Add hidden pattern: UK females on mobile get 3x impressions
mask = (data['gender'] == 'F') & (data['country'] == 'UK') & (data['device'] == 'Mobile')
data.loc[mask, 'impressions'] = np.random.poisson(300, mask.sum())
# Discover the pattern
forest = SegmentationForest(n_trees=10, max_depth=3, max_features=2)
forest.fit(data, 'impressions')
segments = forest.get_segments(min_support=3, min_divergence=0.3)
# Result: Discovers the hidden pattern!
# Output: "gender == F AND country == UK AND device == Mobile"
# Divergence: 0.948, Support: 7/10 trees
See examples/advertising_example.py for the complete example.
🛠️ Development
Setup
# Clone repository
git clone https://github.com/davidgeorgewilliams/segmentation-forests.git
cd segmentation-forests
# Install with dev dependencies
pip install -e ".[dev]"
# Install pre-commit hooks
pre-commit install
Running Tests
# Run all tests
pytest
# With coverage
pytest --cov=segmentation_forests --cov-report=html
# Run specific test
pytest tests/test_tree.py -v
Code Quality
# Format code
black src/ tests/
isort src/ tests/
# Lint
ruff check src/ tests/
# Type check
mypy src/
🤝 Contributing
Contributions are welcome! Please:
- Fork the repository
- Create a feature branch (git checkout -b feature/amazing-feature)
- Make your changes and add tests
- Ensure all tests pass and code is formatted
- Submit a pull request
📄 License
This project is licensed under the MIT License - see the LICENSE file for details.
📚 Citation
If you use Segmentation Forests in your research or project, please cite:
@software{segmentation_forests,
author = {Williams, David},
title = {Segmentation Forests: Unsupervised Segment Discovery using Divergence-based Decision Trees},
year = {2025},
url = {https://github.com/davidgeorgewilliams/segmentation-forests}
}
🙏 Acknowledgments
- Algorithm inspired by Random Forests (Breiman, 2001)
- Divergence measures from information theory (Kullback-Leibler, Jensen-Shannon)
- Built with NumPy, pandas, SciPy, matplotlib, and seaborn
📞 Contact
David Williams - david@davidgeorgewilliams.com
Project Link: https://github.com/davidgeorgewilliams/segmentation-forests
Happy Discovering! 🎯🌲