
Unsupervised segment discovery using divergence-based decision trees inspired by Random Forests


Segmentation Forests

License: MIT | Python 3.12+


Automatically discover meaningful segments in your data where metric distributions significantly diverge from the background. Perfect for exploratory data analysis, anomaly detection, and customer segmentation without requiring labeled data.


🎯 The Problem

You have a dataset with many categorical features and a metric you care about. For example:

  • E-commerce: Users across countries, devices, age groups → conversion rate
  • Digital advertising: Impressions across demographics, platforms, times → CTR
  • Healthcare: Patients across conditions, treatments, demographics → readmission rate
  • Finance: Transactions across customer segments, times → fraud rate

The question: Which specific combinations of features exhibit unusual behavior?

With N features and many values per feature, the number of candidate combinations grows exponentially, so exhaustively testing them quickly becomes impractical. Segmentation Forests instead searches greedily for segments where your metric's distribution differs significantly from the overall population.
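To see why brute force breaks down, count the candidate segments. Assuming every segment is a conjunction in which each feature is either left unconstrained or fixed to one of its values (the `feature == value` form shown throughout this README), a feature with k values contributes k + 1 choices. The total is the product of (k + 1) over all features, minus the empty conjunction. The helper name below is hypothetical:

```python
import math

def count_conjunctive_segments(cardinalities):
    """Number of non-empty segments of the form f1 == v1 AND f2 == v2 AND ...

    Each feature is either unconstrained or fixed to one of its k values,
    giving k + 1 options per feature; subtract 1 for the all-unconstrained case.
    """
    return math.prod(k + 1 for k in cardinalities) - 1

# Four features with 5, 3, 2 and 4 values (as in the advertising example)
print(count_conjunctive_segments([5, 3, 2, 4]))  # 359
# Twenty features with 10 values each
print(count_conjunctive_segments([10] * 20))     # 672749994932560009200
```

A handful of small features is still tractable, but realistic feature counts explode far beyond what can be scored one by one, which is why a greedy search is used.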


✨ Key Features

  • 🌳 Tree-based discovery: Greedy algorithm efficiently navigates combinatorial feature space
  • 🌲 Forest ensemble: Bootstrap aggregating for robust, reproducible patterns
  • 📊 Statistical rigor: KS distance (continuous) & Jensen-Shannon divergence (discrete)
  • 📈 Beautiful visualizations: Distribution comparisons and quality assessments
  • 🔬 Fully typed: Complete type hints for excellent IDE support
  • Fast & scalable: Handles datasets with millions of rows

🚀 Quick Start

Installation

pip install segmentation-forests

Basic Example

import pandas as pd
from segmentation_forests import SegmentationTree, SegmentationForest

# Your data: categorical feature columns plus one metric column ('...' stands for your remaining rows)
data = pd.DataFrame({
    'country': ['US', 'UK', 'US', 'UK', ...],
    'device': ['Mobile', 'Desktop', 'Mobile', ...],
    'gender': ['F', 'M', 'F', ...],
    'impressions': [245, 103, 312, 98, ...]  # Your metric
})

# Discover segments with a single tree
tree = SegmentationTree(max_depth=3, min_samples_split=100)
tree.fit(data, metric_column='impressions')
segments = tree.get_segments(min_divergence=0.1)

# View results
for i, seg in enumerate(segments[:3], 1):
    print(f"{i}. {seg.get_condition_string()}")
    print(f"   Divergence: {seg.divergence:.3f} | Size: {seg.size:,}")

Example output (from a dataset that also includes a time_of_day column):

1. gender == F AND device == Mobile AND country == UK
   Divergence: 0.948 | Size: 523

2. time_of_day == Evening AND country == US AND device == Desktop
   Divergence: 0.856 | Size: 412

3. country == DE AND time_of_day == Morning
   Divergence: 0.824 | Size: 289

🌲 Using the Forest (Recommended)

For more robust results, use the ensemble approach:

from segmentation_forests import SegmentationForest

# Create forest with bootstrap sampling and random features
forest = SegmentationForest(
    n_trees=10,
    max_depth=3,
    max_features=2,  # Random feature selection
    min_samples_split=100,
    min_samples_leaf=50
)

# Fit and get segments found by multiple trees
forest.fit(data, metric_column='impressions')
robust_segments = forest.get_segments(min_support=3, min_divergence=0.1)

# View results
for seg in robust_segments:
    cond_str = " AND ".join(f"{col} {op} {val}" for col, op, val in seg['conditions'])
    print(cond_str)
    print(f"  Support: {seg['support']}/10 trees ({seg['support_rate']*100:.0f}%)")
    print(f"  Avg Divergence: {seg['avg_divergence']:.3f}")
    print()

📊 Visualization

Beautiful distribution comparison plots:

from segmentation_forests.visualization import plot_segment_comparison

# Compare segment distribution vs background
fig = plot_segment_comparison(
    data=data,
    segment_conditions=[('country', '==', 'UK'), ('device', '==', 'Mobile')],
    metric_column='impressions',
    title='UK Mobile Users vs Background'
)
fig.savefig('segment_comparison.png', dpi=150)

Example output:

The plot shows:

  • Left: Overlapping histograms (background in blue, segment in coral)
  • Right: Box plots comparing distributions
  • Clear separation: Strong segments show minimal overlap

🧠 How It Works

Algorithm Overview

  1. Compute Background Distribution: Calculate the distribution of your metric across all data
  2. Greedy Tree Building:
    • At each node, evaluate all feature-value splits
    • Choose the split that maximizes divergence from background
    • Recursively build left (matching condition) and right (not matching) subtrees
  3. Collect High-Divergence Leaves: Return leaf segments whose divergence from the background exceeds the threshold
  4. Ensemble Aggregation (Forest only): Count how many trees found each segment (its support) and keep the segments found by enough trees
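The greedy step above can be sketched in a few lines of standard-library Python. This is an illustrative re-implementation, not the package's code: the helper names `ks_distance` and `best_split` are hypothetical, rows are plain dicts, only equality splits are considered, and only the continuous-metric case (KS distance) is shown. Each candidate `feature == value` subset is scored against the fixed background distribution, and candidates that would leave either child below `min_samples_leaf` are skipped.

```python
from bisect import bisect_right

def ks_distance(a, b):
    """Two-sample Kolmogorov-Smirnov statistic: the largest gap between empirical CDFs."""
    a, b = sorted(a), sorted(b)
    # The empirical CDFs only change at observed values, so checking those suffices.
    return max(
        abs(bisect_right(a, x) / len(a) - bisect_right(b, x) / len(b))
        for x in a + b
    )

def best_split(rows, features, metric, background, min_samples_leaf=1):
    """Return (divergence, feature, value) for the best `feature == value` split, or None."""
    best = None
    for feature in features:
        for value in sorted({row[feature] for row in rows}):
            matching = [row[metric] for row in rows if row[feature] == value]
            rest = len(rows) - len(matching)
            # Respect the leaf-size constraint on both children.
            if len(matching) < min_samples_leaf or rest < min_samples_leaf:
                continue
            score = ks_distance(matching, background)
            if best is None or score > best[0]:
                best = (score, feature, value)
    return best

rows = [
    {"device": "Mobile", "country": "US", "impressions": 10},
    {"device": "Mobile", "country": "UK", "impressions": 11},
    {"device": "Desktop", "country": "US", "impressions": 1},
    {"device": "Desktop", "country": "UK", "impressions": 2},
    {"device": "Tablet", "country": "US", "impressions": 1},
    {"device": "Tablet", "country": "UK", "impressions": 2},
]
background = [row["impressions"] for row in rows]
print(best_split(rows, ["country", "device"], "impressions", background, min_samples_leaf=2))
# (0.6666666666666666, 'device', 'Mobile')
```

A full tree would apply `best_split` recursively to the matching and non-matching rows until `max_depth` or `min_samples_split` stops it.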

Divergence Measures

The algorithm automatically selects the appropriate measure:

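| Metric type | Measure | Range | Description |
|---|---|---|---|
| Continuous | Kolmogorov-Smirnov | [0, 1] | Maximum distance between the two CDFs |
| Discrete | Jensen-Shannon | [0, 1] (base-2 logs) | Symmetrized KL divergence, measured against the average of the two distributions |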

Decision threshold: ≤20 unique values → discrete, >20 → continuous
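For discrete metrics, the Jensen-Shannon divergence compares each empirical distribution P and Q with their average M = (P + Q) / 2 and averages the two terms: JS(P, Q) = ½ KL(P ‖ M) + ½ KL(Q ‖ M). With base-2 logarithms it lies in [0, 1], matching the range in the table. The sketch below is a standard-library illustration, not the package's implementation; the function names and the `max_discrete_levels` parameter are hypothetical, with the default of 20 taken from the decision threshold above.

```python
from collections import Counter
from math import log2

def js_divergence(a, b):
    """Jensen-Shannon divergence (base 2) between the empirical distributions of a and b."""
    count_a, count_b = Counter(a), Counter(b)
    support = count_a.keys() | count_b.keys()
    p = {v: count_a[v] / len(a) for v in support}
    q = {v: count_b[v] / len(b) for v in support}
    m = {v: 0.5 * (p[v] + q[v]) for v in support}

    def kl_to_m(dist):
        # Values with zero probability contribute nothing to the KL sum.
        return sum(dist[v] * log2(dist[v] / m[v]) for v in support if dist[v] > 0)

    return 0.5 * kl_to_m(p) + 0.5 * kl_to_m(q)

def choose_measure(metric_values, max_discrete_levels=20):
    """Pick a divergence measure from the number of distinct metric values."""
    if len(set(metric_values)) <= max_discrete_levels:
        return "jensen_shannon"
    return "kolmogorov_smirnov"

print(js_divergence([0, 0, 1, 1], [0, 1, 0, 1]))  # 0.0 (identical distributions)
print(js_divergence([0, 0], [1, 1]))              # 1.0 (disjoint supports)
print(choose_measure([3, 1, 4, 1, 5]))            # jensen_shannon
```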

Quality Guidelines

Interpret divergence scores:

  • ≥ 0.5: 🎯 Excellent - Strong, highly actionable pattern
  • 0.3-0.5: ✓ Good - Meaningful difference worth investigating
  • 0.1-0.3: ⚠️ Weak - Marginal effect, could be noise
  • < 0.1: ❌ Very weak - Likely statistical noise
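If you want to tag results programmatically, the bands above translate into a small helper. This is a hypothetical convenience function, not part of the package; it assumes each band includes its lower bound (so 0.3 counts as Good and 0.1 as Weak), which the list leaves open.

```python
def divergence_quality(score: float) -> str:
    """Map a divergence score to the quality bands listed above."""
    if score >= 0.5:
        return "Excellent"
    if score >= 0.3:
        return "Good"
    if score >= 0.1:
        return "Weak"
    return "Very weak"

for score in (0.948, 0.35, 0.12, 0.04):
    print(f"{score:.3f}: {divergence_quality(score)}")
```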

📖 API Reference

SegmentationTree

SegmentationTree(
    max_depth: int = 5,
    min_samples_split: int = 50,
    min_samples_leaf: int = 20,
    divergence_threshold: float = 0.01,
    random_features: Optional[int] = None
)

Parameters:

  • max_depth: Maximum tree depth (controls segment complexity)
  • min_samples_split: Minimum samples required to split a node
  • min_samples_leaf: Minimum samples required in each child
  • divergence_threshold: Minimum divergence to keep a segment
  • random_features: Number of random features per split (None = use all)

Methods:

  • fit(data: pd.DataFrame, metric_column: str) -> Self: Fit tree to data
  • get_segments(min_divergence: float = 0.0) -> List[SegmentationNode]: Get segments

SegmentationForest

SegmentationForest(
    n_trees: int = 10,
    max_depth: int = 5,
    min_samples_split: int = 50,
    min_samples_leaf: int = 20,
    divergence_threshold: float = 0.01,
    max_features: Optional[int] = None
)

Parameters:

  • n_trees: Number of trees in the forest
  • Other parameters: same as SegmentationTree, with max_features playing the role of random_features (number of randomly selected features)

Methods:

  • fit(data: pd.DataFrame, metric_column: str) -> Self: Fit forest
  • get_segments(min_support: int = 2, min_divergence: float = 0.0) -> List[Dict]: Get robust segments

Returns: List of dicts with keys:

  • conditions: List of (column, operator, value) tuples
  • support: Number of trees that found this segment
  • avg_divergence: Average divergence across trees
  • avg_size: Average segment size
  • support_rate: Fraction of trees (support / n_trees)
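One way to produce these fields is to group the segments from all trees by their conditions and count how many trees found each one. The sketch below is a hypothetical standard-library illustration of that voting step, not the package's code. It assumes each tree yields `(conditions, divergence, size)` triples, treats two segments as identical when they contain the same set of conditions in any order, counts a segment at most once per tree, and applies `min_divergence` to the averaged divergence.

```python
from collections import defaultdict

def aggregate_segments(per_tree_segments, n_trees, min_support=2, min_divergence=0.0):
    """Vote across trees: group identical segments and summarize each group."""
    hits = defaultdict(list)  # frozenset of conditions -> [(divergence, size), ...]
    for segments in per_tree_segments:
        seen = set()
        for conditions, divergence, size in segments:
            key = frozenset(conditions)  # condition order does not matter
            if key in seen:
                continue  # count each segment once per tree
            seen.add(key)
            hits[key].append((divergence, size))

    results = []
    for key, found in hits.items():
        support = len(found)
        avg_divergence = sum(d for d, _ in found) / support
        if support >= min_support and avg_divergence >= min_divergence:
            results.append({
                "conditions": sorted(key),
                "support": support,
                "avg_divergence": avg_divergence,
                "avg_size": sum(s for _, s in found) / support,
                "support_rate": support / n_trees,
            })
    # Most widely supported (then most divergent) segments first.
    return sorted(results, key=lambda r: (-r["support"], -r["avg_divergence"]))

trees = [
    [((("country", "==", "UK"), ("device", "==", "Mobile")), 0.8, 100)],
    [((("device", "==", "Mobile"), ("country", "==", "UK")), 0.6, 120)],
    [((("country", "==", "DE"),), 0.4, 50)],
]
for seg in aggregate_segments(trees, n_trees=3):
    print(seg["conditions"], seg["support"], round(seg["avg_divergence"], 3), seg["avg_size"])
```

Keying on a frozenset is the design choice that lets "country == UK AND device == Mobile" and "device == Mobile AND country == UK" count as the same segment.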

SegmentationNode

Represents a discovered segment.

Attributes:

  • conditions: List of (column, operator, value) tuples
  • divergence: Divergence score
  • size: Number of data points
  • depth: Depth in tree
  • data_indices: Indices of data points in this segment

Methods:

  • get_condition_string() -> str: Human-readable condition string
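Because a segment is just a list of `(column, operator, value)` tuples, it is easy to re-apply outside the package, for example to pull matching rows from another dataset. The helpers below are hypothetical and work on plain dict rows. They assume only `==` and `!=` operators, since this README shows `==` conditions and describes the right branch of each split as the rows that do not match. The string format mirrors the `get_condition_string()` output shown in the Quick Start.

```python
import operator

# Assumed operator set; extend it if your segments use other operators.
OPERATORS = {"==": operator.eq, "!=": operator.ne}

def matches(row, conditions):
    """True if the row satisfies every (column, operator, value) condition."""
    return all(OPERATORS[op](row[column], value) for column, op, value in conditions)

def condition_string(conditions):
    """Render conditions like 'country == UK AND device == Mobile'."""
    return " AND ".join(f"{column} {op} {value}" for column, op, value in conditions)

conditions = [("country", "==", "UK"), ("device", "==", "Mobile")]
rows = [
    {"country": "UK", "device": "Mobile", "impressions": 310},
    {"country": "UK", "device": "Desktop", "impressions": 95},
    {"country": "US", "device": "Mobile", "impressions": 102},
]
print(condition_string(conditions))  # country == UK AND device == Mobile
print([row["impressions"] for row in rows if matches(row, conditions)])  # [310]
```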

🎨 Visualization Functions

plot_segment_comparison

plot_segment_comparison(
    data: pd.DataFrame,
    segment_conditions: List[Tuple],
    metric_column: str,
    title: Optional[str] = None,
    figsize: Tuple = (14, 5)
) -> plt.Figure

Creates a side-by-side histogram and box-plot comparison of the segment against the background.


💡 Usage Tips

Choosing Parameters

For max_depth:

  • depth=2: Simple 2-condition segments (e.g., "Country=UK AND Device=Mobile")
  • depth=3-4: Recommended - Balanced complexity
  • depth=5+: Complex segments, risk of overfitting

For min_divergence:

  • Start with 0.1 to surface a broad set of candidate patterns
  • Increase to 0.3+ to focus only on strong effects
  • Alternatively, keep the threshold low and rely on the forest's min_support to filter out noise

For forest:

  • n_trees=10: Good default
  • n_trees=20+: More robust but slower
  • max_features=sqrt(n_features): Good for high-dimensional data
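Since max_features expects an integer, the square-root rule needs rounding. A minimal sketch, assuming you round down and keep at least one feature (the rounding choice is ours, not the package's; the helper name is hypothetical):

```python
import math

def sqrt_max_features(n_features: int) -> int:
    """Square-root heuristic for the number of randomly selected features."""
    return max(1, math.isqrt(n_features))

for n in (4, 10, 50):
    print(n, "features ->", sqrt_max_features(n))
```

For the four-feature advertising data this gives 2, the max_features value used in the examples.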

Interpreting Results

  1. Always visualize top segments to verify they make sense
  2. Check segment size - very small segments may be spurious
  3. Use forest support - patterns in 5+/10 trees are highly reliable
  4. Domain validation - do discovered segments align with business intuition?

🔬 Example: Advertising Dataset

from segmentation_forests import SegmentationForest
from segmentation_forests.visualization import plot_segment_comparison
import pandas as pd
import numpy as np

# Create synthetic advertising data
np.random.seed(42)
n = 10000

data = pd.DataFrame({
    'country': np.random.choice(['US', 'UK', 'CA', 'DE', 'FR'], n),
    'device': np.random.choice(['Mobile', 'Desktop', 'Tablet'], n),
    'gender': np.random.choice(['M', 'F'], n),
    'time_of_day': np.random.choice(['Morning', 'Afternoon', 'Evening', 'Night'], n),
    'impressions': np.random.poisson(100, n)  # Base: ~100 impressions
})

# Add hidden pattern: UK females on mobile get 3x impressions
mask = (data['gender'] == 'F') & (data['country'] == 'UK') & (data['device'] == 'Mobile')
data.loc[mask, 'impressions'] = np.random.poisson(300, mask.sum())

# Discover the pattern
forest = SegmentationForest(n_trees=10, max_depth=3, max_features=2)
forest.fit(data, 'impressions')
segments = forest.get_segments(min_support=3, min_divergence=0.3)

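# Print the robust segments
for seg in segments:
    print(" AND ".join(f"{col} {op} {val}" for col, op, val in seg['conditions']))
    print(f"  Avg Divergence: {seg['avg_divergence']:.3f}, Support: {seg['support']}/10 trees")

# Reported result: the forest recovers the hidden pattern
# "gender == F AND country == UK AND device == Mobile"
# Divergence: 0.948, Support: 7/10 trees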
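When judging the segments the forest returns, it helps to know roughly how large the planted segment is. Because country, device, and gender are each drawn independently and uniformly in the synthetic data above, the UK, Mobile, female rows make up about 1/5 × 1/3 × 1/2 = 1/30 of the data:

```python
from fractions import Fraction

n = 10_000
# Independent uniform draws: 5 countries, 3 devices, 2 genders.
share = Fraction(1, 5) * Fraction(1, 3) * Fraction(1, 2)
expected_size = n * share
print(share, float(expected_size))  # 1/30 333.3333333333333
```

So the hidden pattern covers only about 333 of the 10,000 rows, a small pocket that a greedy search still has to isolate.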

See examples/advertising_example.py for the complete example.


🛠️ Development

Setup

# Clone repository
git clone https://github.com/davidgeorgewilliams/segmentation-forests.git
cd segmentation-forests

# Install with dev dependencies
pip install -e ".[dev]"

# Install pre-commit hooks
pre-commit install

Running Tests

# Run all tests
pytest

# With coverage
pytest --cov=segmentation_forests --cov-report=html

# Run specific test
pytest tests/test_tree.py -v

Code Quality

# Format code
black src/ tests/
isort src/ tests/

# Lint
ruff check src/ tests/

# Type check
mypy src/

🤝 Contributing

Contributions are welcome! Please:

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Make your changes and add tests
  4. Ensure all tests pass and code is formatted
  5. Submit a pull request

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.


📚 Citation

If you use Segmentation Forests in your research or project, please cite:

@software{segmentation_forests,
  author = {Williams, David},
  title = {Segmentation Forests: Unsupervised Segment Discovery using Divergence-based Decision Trees},
  year = {2025},
  url = {https://github.com/davidgeorgewilliams/segmentation-forests}
}

🙏 Acknowledgments

  • Algorithm inspired by Random Forests (Breiman, 2001)
  • Divergence measures from information theory (Kullback-Leibler, Jensen-Shannon)
  • Built with NumPy, pandas, SciPy, matplotlib, and seaborn

📞 Contact

David Williams - david@davidgeorgewilliams.com

Project Link: https://github.com/davidgeorgewilliams/segmentation-forests


Happy Discovering! 🎯🌲



Download files

Download the file for your platform.

Source Distribution

segmentation_forests-0.1.0.tar.gz (18.8 kB)

Built Distribution

segmentation_forests-0.1.0-py3-none-any.whl (14.9 kB)

File details

Details for the file segmentation_forests-0.1.0.tar.gz.

File metadata

  • Download URL: segmentation_forests-0.1.0.tar.gz
  • Size: 18.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/6.2.0 CPython/3.12.7

File hashes

Hashes for segmentation_forests-0.1.0.tar.gz
  • SHA256: b267564400c65a56f257a61f230806b7cf4ac2bbdca6f364b0130244e2f5a31b
  • MD5: a8b39edc1f35b61cbf44450ed9c06cb1
  • BLAKE2b-256: fe5357b88e89061d70701b69321c40b3a6a0f77a6e4a1917d430af125f7decb0


File details

Details for the file segmentation_forests-0.1.0-py3-none-any.whl.


File hashes

Hashes for segmentation_forests-0.1.0-py3-none-any.whl
  • SHA256: a4c29983caab4a2153a362c7ed209b67a9cdc34536ba49982224205e2f27089a
  • MD5: c11186c1abdd157fff250cbcdf4237e4
  • BLAKE2b-256: 77f4d9a51443e6fbeb45d21b9cd35397742caddff78640c268491bbe727b5d52

