CANS: Causal Adaptive Neural System
CANS (Causal Adaptive Neural System) is a production-ready deep learning framework for causal inference on structured, textual, and heterogeneous data. It seamlessly integrates Graph Neural Networks (GNNs), Transformers (e.g., BERT), a Gated Fusion mechanism, and Counterfactual Regression Networks (CFRNet).
Initially developed to model misinformation propagation on social networks, CANS generalizes to domains such as healthcare, legal, and finance, offering a robust, well-tested pipeline for real-world causal modeling and counterfactual simulation.
What's New in v3.0 - Enhanced Causal Inference
- Causal Assumption Testing: Automated testing of unconfoundedness, positivity, and SUTVA
- Multiple Identification Strategies: Backdoor criterion, IPW, doubly robust estimation
- CATE Estimation: X-Learner, T-Learner, S-Learner, Neural CATE, Causal Forest
- Uncertainty Quantification: Bayesian methods, ensemble approaches, conformal prediction
- Advanced Graph Construction: Multi-node, temporal, and global graph architectures
- Causal-Specific Losses: CFR, IPW, DragonNet, TARNet with representation balancing
- Comprehensive Evaluation: PEHE, policy evaluation, calibration metrics
- Memory-Efficient Processing: Lazy loading and batch processing for large datasets
Previous v2.0 Features:
- Configuration Management: Centralized, validated configs with JSON/YAML support
- Enhanced Error Handling: Comprehensive validation with informative error messages
- Logging & Checkpointing: Built-in experiment tracking with automatic model saving
- Comprehensive Testing: 100+ unit tests ensuring production reliability
- Advanced Data Pipeline: Multi-format loading (CSV, JSON) with automatic preprocessing
- Enhanced Training: Early stopping, gradient clipping, multiple loss functions
Key Features
Core Architecture
- Hybrid Neural Architecture: GNNs + Transformers + CFRNet for multi-modal causal inference
- Gated Fusion Layer: Adaptive mixing of graph and textual representations
- Flexible Graph Construction: Single-node, multi-node, temporal, and global graphs
- Production-Ready: Comprehensive error handling, logging, and testing
Causal Inference Capabilities
- Rigorous Assumption Testing: Automated validation of causal identification conditions
- Multiple Identification Methods: Backdoor, IPW, doubly robust, with sensitivity analysis
- Heterogeneous Treatment Effects: CATE estimation with 5+ methods (X/T/S-Learners, etc.)
- Advanced Loss Functions: CFR, DragonNet, TARNet with representation balancing
- Uncertainty Quantification: Bayesian, ensemble, conformal prediction approaches
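As a concrete illustration of the learner family mentioned above, a T-Learner fits separate outcome models on treated and control units and takes the difference of their predictions as the CATE. The sketch below uses scikit-learn on synthetic data; it does not use the CANS API, and all variable names are illustrative:

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor

rng = np.random.default_rng(1)
n = 2000
X = rng.normal(size=(n, 4))                 # covariates
T = rng.integers(0, 2, size=n)              # randomized binary treatment
# Synthetic outcome with a heterogeneous effect: tau(x) = 1 + x0
Y = X[:, 0] + T * (1.0 + X[:, 0]) + rng.normal(scale=0.1, size=n)

# T-Learner: one outcome model per treatment arm, CATE = mu1(x) - mu0(x)
m0 = GradientBoostingRegressor().fit(X[T == 0], Y[T == 0])
m1 = GradientBoostingRegressor().fit(X[T == 1], Y[T == 1])
cate = m1.predict(X) - m0.predict(X)
print(f"mean CATE: {cate.mean():.2f}")  # true average effect is 1.0
```

X- and S-Learners refine this same idea (cross-fitting the imputed effects, or fitting a single model with the treatment as a feature).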
Data Processing & Evaluation
- Smart Data Loading: CSV, JSON, synthetic data with automatic graph construction
- Comprehensive Evaluation: PEHE, ATE, policy value, calibration metrics
- Memory Efficiency: Lazy loading, batch processing for large-scale datasets
- Easy Configuration: JSON/YAML configs with validation and experiment tracking
Architecture

      +-----------+        +-----------+
      |  GNN Emb  |        |  BERT Emb |
      +-----------+        +-----------+
             \                  /
              \  Fusion Layer  /
               \              /
                +-----------+
                | Fused Rep |
                +-----------+
                      |
                    CFRNet
                   /      \
              mu_0(x)   mu_1(x)
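The diagram can be sketched in PyTorch roughly as follows. This is an illustrative reconstruction of the idea (a learned gate mixing graph and text embeddings, feeding two CFRNet-style outcome heads), not the framework's actual `CANS` module:

```python
import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    """Learned gate mixes graph and text embeddings feature-by-feature."""
    def __init__(self, dim: int):
        super().__init__()
        self.gate = nn.Linear(2 * dim, dim)

    def forward(self, g_emb: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        # Gate in [0, 1]: how much of each modality to keep per feature.
        z = torch.sigmoid(self.gate(torch.cat([g_emb, t_emb], dim=-1)))
        return z * g_emb + (1.0 - z) * t_emb

class TwoHeadOutcome(nn.Module):
    """Two outcome heads on the fused representation, as in CFRNet/TARNet."""
    def __init__(self, dim: int):
        super().__init__()
        self.mu0 = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1))
        self.mu1 = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, 1))

    def forward(self, phi: torch.Tensor):
        return self.mu0(phi), self.mu1(phi)

fused = GatedFusion(dim=256)(torch.randn(8, 256), torch.randn(8, 256))
mu0, mu1 = TwoHeadOutcome(dim=256)(fused)
print(mu0.shape, mu1.shape)  # torch.Size([8, 1]) torch.Size([8, 1])
```

The predicted treatment effect for a unit is then simply `mu1 - mu0` on its fused representation.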
Enhanced Causal Analysis Workflow
Complete Example with Assumption Testing & CATE Estimation
from cans import (
    CANSConfig, CANS, GCN, CANSRunner,
    create_sample_dataset, get_data_loaders,
    CausalAssumptionTester, CausalLossManager,
    CATEManager, UncertaintyManager,
    advanced_counterfactual_analysis
)
import torch
# 1. Configuration with enhanced causal features
config = CANSConfig()
config.model.gnn_type = "GCN"
config.training.loss_type = "cfr" # Causal loss function
config.data.graph_construction = "global" # Multi-node graphs
# 2. Test causal assumptions BEFORE modeling
# (X, T, Y: covariate matrix, binary treatment vector, outcome vector)
assumption_tester = CausalAssumptionTester()
results = assumption_tester.comprehensive_test(X, T, Y)
print(f"Causal assumptions valid: {results['causal_identification_valid']}")
# 3. Create datasets with enhanced graph construction
datasets = create_sample_dataset(n_samples=1000, config=config.data)
train_loader, val_loader, test_loader = get_data_loaders(datasets)
# 4. Setup model with causal loss functions
from transformers import AutoModel  # DistilBERT checkpoints need AutoModel, not BertModel
gnn = GCN(in_dim=64, hidden_dim=128, output_dim=256)
bert = AutoModel.from_pretrained("distilbert-base-uncased")
model = CANS(gnn, bert, fusion_dim=256)
loss_manager = CausalLossManager("cfr", alpha=1.0, beta=0.5)
# 5. Train with causal-aware pipeline
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
runner = CANSRunner(model, optimizer, config)
history = runner.fit(train_loader, val_loader)
# 6. Multiple counterfactual identification methods
cf_results = advanced_counterfactual_analysis(
model, test_loader,
methods=['backdoor', 'ipw', 'doubly_robust']
)
# 7. CATE estimation with multiple learners
cate_manager = CATEManager(method="x_learner")
cate_manager.fit(X, T, Y)
individual_effects = cate_manager.estimate_cate(X_test)  # X_test: held-out covariates
# 8. Uncertainty quantification
uncertainty_manager = UncertaintyManager(method="conformal")
uncertainty_manager.setup(model)
intervals = uncertainty_manager.estimate_uncertainty(test_loader)
print(f"ATE: {cf_results['backdoor']['ate']:.3f}")
print(f"Coverage: {intervals['coverage_rate']:.3f}")
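As background on the conformal coverage reported above: split conformal prediction takes the (1 - alpha) quantile of absolute residuals on a held-out calibration set as a symmetric interval half-width. A minimal NumPy sketch with synthetic stand-ins for model predictions (`UncertaintyManager` internals may differ):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic calibration split: `preds` stand in for model predictions.
y = rng.normal(size=500)
preds = y + rng.normal(scale=0.5, size=500)

# Half-width: finite-sample-corrected (1 - alpha) quantile of |residual|.
alpha = 0.1
residuals = np.abs(y - preds)
q = np.quantile(residuals, np.ceil((1 - alpha) * (len(y) + 1)) / len(y))

# New points: intervals preds_new +/- q cover ~90% of true outcomes.
y_new = rng.normal(size=200)
preds_new = y_new + rng.normal(scale=0.5, size=200)
covered = np.abs(y_new - preds_new) <= q
print(f"coverage: {covered.mean():.2f}")
```

The coverage guarantee is marginal (on average over units), which is why the workflow prints a `coverage_rate` to check calibration empirically.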
Quick Start
Installation
# Clone the repository
git clone https://github.com/rdmurugan/cans-framework.git
cd cans-framework
# Install dependencies
pip install -r requirements.txt
Core Dependencies:
torch>=2.0.0
transformers>=4.38.0
torch-geometric>=2.3.0
scikit-learn>=1.3.0
pandas>=2.0.0
Basic Usage (30 seconds to results)
from cans.config import CANSConfig
from cans.utils.data import create_sample_dataset, get_data_loaders
from cans.models import CANS
from cans.models.gnn_modules import GCN
from cans.pipeline.runner import CANSRunner
from transformers import BertModel
import torch
# 1. Create configuration
config = CANSConfig()
config.training.epochs = 10
# 2. Load data (or create sample data)
datasets = create_sample_dataset(n_samples=1000, n_features=64)
train_loader, val_loader, test_loader = get_data_loaders(datasets, batch_size=32)
# 3. Create model
gnn = GCN(in_dim=64, hidden_dim=128, output_dim=256)
bert = BertModel.from_pretrained("bert-base-uncased")
model = CANS(gnn, bert, fusion_dim=256)
# 4. Train
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
runner = CANSRunner(model, optimizer, config)
history = runner.fit(train_loader, val_loader)
# 5. Evaluate
results = runner.evaluate(test_loader)
print(f"Test MSE: {results['mse']:.4f}")
print(f"Average Treatment Effect: {results['ate']:.4f}")
Usage Examples
Example 1: CSV Data with Real Causal Inference
from cans.utils.data import load_csv_dataset
from cans.config import CANSConfig, DataConfig
# Configure data processing
config = CANSConfig()
config.data.graph_construction = "knn" # or "similarity"
config.data.knn_k = 5
config.data.scale_node_features = True
# Load your CSV data
datasets = load_csv_dataset(
csv_path="your_data.csv",
text_column="review_text", # Column with text data
treatment_column="intervention", # Binary treatment (0/1)
outcome_column="conversion_rate", # Continuous outcome
feature_columns=["age", "income", "education"], # Numerical features
config=config.data
)
train_dataset, val_dataset, test_dataset = datasets
# Check data quality
stats = train_dataset.get_statistics()
print(f"Treatment proportion: {stats['treatment_proportion']:.3f}")
print(f"Propensity overlap valid: {stats['propensity_overlap_valid']}")
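The overlap check above can be approximated by hand with a propensity model: fit a classifier for treatment given covariates and verify the estimated propensities stay away from 0 and 1 (positivity). A rough sketch on synthetic data; the names below are illustrative, not the CANS API:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 5))  # covariates (synthetic stand-in)
# Confounded treatment: assignment depends on the first covariate.
T = (rng.random(1000) < 1 / (1 + np.exp(-X[:, 0]))).astype(int)

# Positivity/overlap: estimated propensities should avoid the extremes.
ps = LogisticRegression().fit(X, T).predict_proba(X)[:, 1]
eps = 0.05
overlap_ok = bool(((ps > eps) & (ps < 1 - eps)).mean() > 0.95)
print(f"propensity range: [{ps.min():.3f}, {ps.max():.3f}], overlap ok: {overlap_ok}")
```

Units with propensities near 0 or 1 have essentially no comparable counterparts in the other arm, so causal estimates for them rest on extrapolation.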
Example 2: Advanced Configuration & Experiment Tracking
from cans.config import CANSConfig
# Create detailed configuration
config = CANSConfig()
# Model configuration
config.model.gnn_type = "GCN"
config.model.gnn_hidden_dim = 256
config.model.fusion_dim = 512
config.model.text_model = "distilbert-base-uncased" # Faster BERT variant
# Training configuration
config.training.learning_rate = 0.001
config.training.batch_size = 64
config.training.epochs = 50
config.training.early_stopping_patience = 10
config.training.gradient_clip_norm = 1.0
config.training.loss_type = "huber" # Robust to outliers
# Experiment tracking
config.experiment.experiment_name = "healthcare_causal_analysis"
config.experiment.save_every_n_epochs = 5
config.experiment.log_level = "INFO"
# Save configuration for reproducibility
config.save("experiment_config.json")
# Later: load and use
loaded_config = CANSConfig.load("experiment_config.json")
Example 3: Counterfactual Analysis & Treatment Effects
from cans.utils.causal import simulate_counterfactual
import numpy as np
# After training your model...
runner = CANSRunner(model, optimizer, config)
runner.fit(train_loader, val_loader)
# Comprehensive evaluation
test_metrics = runner.evaluate(test_loader)
print("Performance Metrics:")
for metric, value in test_metrics.items():
print(f" {metric}: {value:.4f}")
# Counterfactual analysis
cf_control = simulate_counterfactual(model, test_loader, intervention=0)
cf_treatment = simulate_counterfactual(model, test_loader, intervention=1)
# Calculate causal effects
ate = np.mean(cf_treatment) - np.mean(cf_control)
print(f"\nCausal Analysis:")
print(f"Average Treatment Effect (ATE): {ate:.4f}")
print(f"Expected outcome under control: {np.mean(cf_control):.4f}")
print(f"Expected outcome under treatment: {np.mean(cf_treatment):.4f}")
# Individual treatment effects
individual_effects = np.array(cf_treatment) - np.array(cf_control)
print(f"Treatment effect std: {np.std(individual_effects):.4f}")
print(f"% benefiting from treatment: {(individual_effects > 0).mean()*100:.1f}%")
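Point estimates like the ATE above are more useful with an uncertainty estimate; one simple option is a nonparametric bootstrap over the individual effects. A sketch with synthetic stand-ins for the counterfactual predictions (not part of the CANS API):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-ins for per-unit counterfactual predictions, true effect ~0.5.
cf_control = rng.normal(loc=1.0, scale=1.0, size=500)
cf_treatment = cf_control + 0.5 + rng.normal(scale=0.3, size=500)

# Bootstrap the mean of individual effects to get a 95% percentile CI.
effects = cf_treatment - cf_control
boot = np.array([
    rng.choice(effects, size=len(effects), replace=True).mean()
    for _ in range(2000)
])
lo, hi = np.percentile(boot, [2.5, 97.5])
print(f"ATE: {effects.mean():.3f}, 95% CI: [{lo:.3f}, {hi:.3f}]")
```

Note this only captures sampling variability of the predictions, not model uncertainty; the framework's uncertainty quantification methods address the latter.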
Example 4: Custom Data Pipeline
from cans.utils.preprocessing import DataPreprocessor, GraphBuilder
from cans.config import DataConfig
import pandas as pd
# Custom preprocessing pipeline
config = DataConfig()
config.graph_construction = "similarity"
config.similarity_threshold = 0.7
config.scale_node_features = True
preprocessor = DataPreprocessor(config)
# Process your DataFrame
df = pd.read_csv("social_media_posts.csv")
dataset = preprocessor.process_tabular_data(
data=df,
text_column="post_content",
treatment_column="fact_check_label",
outcome_column="share_count",
feature_columns=["user_followers", "post_length", "sentiment_score"],
text_model="bert-base-uncased",
max_text_length=256
)
# Split with custom ratios
train_ds, val_ds, test_ds = preprocessor.split_dataset(
dataset,
train_size=0.7,
val_size=0.2,
test_size=0.1
)
Testing & Development
# Run all tests
pytest tests/ -v
# Run specific test categories
pytest tests/test_models.py -v # Model tests
pytest tests/test_validation.py -v # Validation tests
pytest tests/test_pipeline.py -v # Training pipeline tests
# Run with coverage
pytest tests/ --cov=cans --cov-report=html
# Run example scripts
python examples/enhanced_usage_example.py
python examples/enhanced_causal_analysis_example.py
Framework Structure
cans-framework/
├── cans/
│   ├── __init__.py          # Main imports
│   ├── config.py            # ✨ Configuration management
│   ├── exceptions.py        # ✨ Custom exceptions
│   ├── validation.py        # ✨ Data validation utilities
│   ├── models/
│   │   ├── cans.py          # Core CANS model (enhanced)
│   │   └── gnn_modules.py   # GNN implementations
│   ├── pipeline/
│   │   └── runner.py        # ✨ Enhanced training pipeline
│   └── utils/
│       ├── causal.py        # Counterfactual simulation
│       ├── data.py          # ✨ Enhanced data loading
│       ├── preprocessing.py # ✨ Advanced preprocessing
│       ├── logging.py       # ✨ Structured logging
│       └── checkpointing.py # ✨ Model checkpointing
├── tests/                   # ✨ Comprehensive test suite
├── examples/                # Usage examples
└── CLAUDE.md                # Development guide

✨ = New/Enhanced in v2.0
Use Cases & Applications
Healthcare & Medical
# Analyze treatment effectiveness with patient records + clinical notes
datasets = load_csv_dataset(
csv_path="patient_outcomes.csv",
text_column="clinical_notes",
treatment_column="medication_type",
outcome_column="recovery_score",
feature_columns=["age", "bmi", "comorbidities"]
)
Marketing & A/B Testing
# Marketing campaign effectiveness with customer profiles + ad content
datasets = load_csv_dataset(
csv_path="campaign_data.csv",
text_column="ad_content",
treatment_column="campaign_variant",
outcome_column="conversion_rate",
feature_columns=["customer_ltv", "demographics", "behavior_score"]
)
Social Media & Content Moderation
# Impact of content moderation on engagement
datasets = load_csv_dataset(
csv_path="posts_data.csv",
text_column="post_text",
treatment_column="moderation_action",
outcome_column="engagement_score",
feature_columns=["user_followers", "post_length", "sentiment"]
)
Research & Methodology
CANS implements state-of-the-art causal inference techniques:
- Counterfactual Regression Networks (CFRNet): Learn representations that minimize treatment assignment bias
- Gated Fusion: Adaptively combine graph-structured and textual information
- Balanced Representation: Minimize distributional differences between treatment groups
- Propensity Score Validation: Automatic overlap checking for reliable causal estimates
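The balanced-representation idea can be illustrated with the simplest such penalty: a linear MMD between the mean representations of treated and control units, added to the factual outcome loss. This is a sketch of the concept, not the framework's `CausalLossManager` implementation:

```python
import torch

def linear_mmd(phi: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Squared distance between treated and control representation means,
    the simplest of the balancing penalties used in CFR-style losses."""
    phi_t = phi[t == 1]
    phi_c = phi[t == 0]
    return torch.sum((phi_t.mean(dim=0) - phi_c.mean(dim=0)) ** 2)

phi = torch.randn(64, 16)                 # batch of learned representations
t = (torch.rand(64) > 0.5).long()         # treatment indicators
factual_loss = torch.tensor(0.0)          # placeholder for the outcome loss
alpha = 1.0                               # balancing strength
total = factual_loss + alpha * linear_mmd(phi, t)
print(total.item())
```

Pushing this penalty toward zero makes the treated and control representation distributions harder to tell apart, which is what reduces treatment-assignment bias in the learned features.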
Key Papers:
- Shalit et al. "Estimating individual treatment effect: generalization bounds and algorithms" (ICML 2017)
- Yao et al. "Representation learning for treatment effect estimation from observational data" (NeurIPS 2018)
Performance & Scalability
- Memory Efficient: Optimized batch processing and gradient checkpointing
- GPU Acceleration: Full CUDA support with automatic device selection
- Parallel Processing: Multi-core data loading and preprocessing
- Production Ready: Comprehensive error handling and logging
Benchmarks (approximate, hardware-dependent):
- Small: 1K samples, 32 features → ~30 sec training
- Medium: 100K samples, 128 features → ~10 min training
- Large: 1M+ samples → scales with batch size and hardware
Additional Resources
- Documentation: See CLAUDE.md for a detailed development guide
- Examples: Check the examples/ directory for complete workflows
- Tests: tests/ contains 100+ unit tests demonstrating usage
- Issues: Report bugs and feature requests on GitHub
Contributing
Contributions welcome! Please:
- Fork the repository
- Create a feature branch (git checkout -b feature/amazing-feature)
- Add tests for new functionality
- Run tests: pytest tests/ -v
- Submit a pull request
Areas we'd love help with:
- Additional GNN architectures (GraphSAGE, Graph Transformers)
- More evaluation metrics for causal inference
- Integration with popular ML platforms (MLflow, Weights & Biases)
- Performance optimizations
Authors
Durai Rajamanickam (@duraimuruganr). Reach out at durai@infinidatum.net.
License
MIT License. Free to use, modify, and distribute with attribution.
Ready to get started? Try the 30-second quick start above, or dive into the detailed examples!