
🧠 CANS: Causal Adaptive Neural System

CANS (Causal Adaptive Neural System) is a production-ready deep learning framework for causal inference on structured, textual, and heterogeneous data. It seamlessly integrates Graph Neural Networks (GNNs), Transformers (e.g., BERT), a Gated Fusion mechanism, and Counterfactual Regression Networks (CFRNet).

Initially developed for misinformation propagation on social networks, CANS generalizes to domains like healthcare, legal, and finance, offering a robust, well-tested pipeline for real-world causal modeling and counterfactual simulation.

🚀 What's New in v3.0 - Enhanced Causal Inference

  • 🔬 Causal Assumption Testing: Automated testing of unconfoundedness, positivity, and SUTVA
  • 🎯 Multiple Identification Strategies: Backdoor criterion, IPW, doubly robust estimation
  • 📊 CATE Estimation: X-Learner, T-Learner, S-Learner, Neural CATE, Causal Forest
  • 🎲 Uncertainty Quantification: Bayesian methods, ensemble approaches, conformal prediction
  • 🏗️ Advanced Graph Construction: Multi-node, temporal, and global graph architectures
  • 🔄 Causal-Specific Losses: CFR, IPW, DragonNet, TARNet with representation balancing
  • 📈 Comprehensive Evaluation: PEHE, policy evaluation, calibration metrics
  • 💾 Memory-Efficient Processing: Lazy loading and batch processing for large datasets
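For intuition on the evaluation side, PEHE (Precision in Estimation of Heterogeneous Effects) is just the RMSE between predicted and true individual treatment effects. It can only be computed when ground-truth effects are known, i.e. on synthetic or semi-synthetic data; a minimal sketch:

```python
import numpy as np

def pehe(tau_hat: np.ndarray, tau_true: np.ndarray) -> float:
    """PEHE: root-mean-squared error between predicted and true
    individual treatment effects."""
    return float(np.sqrt(np.mean((tau_hat - tau_true) ** 2)))

# Toy check: an estimator off by 0.5 everywhere has PEHE exactly 0.5.
tau_true = np.array([1.0, 2.0, 3.0])
tau_hat = tau_true + 0.5
print(pehe(tau_hat, tau_true))  # 0.5
```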

Previous v2.0 Features:

  • 🔧 Configuration Management: Centralized, validated configs with JSON/YAML support
  • 🛡️ Enhanced Error Handling: Comprehensive validation with informative error messages
  • 📊 Logging & Checkpointing: Built-in experiment tracking with automatic model saving
  • 🧪 Comprehensive Testing: 100+ unit tests ensuring production reliability
  • 📈 Advanced Data Pipeline: Multi-format loading (CSV, JSON) with automatic preprocessing
  • ⚡ Enhanced Training: Early stopping, gradient clipping, multiple loss functions

🔧 Key Features

Core Architecture

  • ✅ Hybrid Neural Architecture: GNNs + Transformers + CFRNet for multi-modal causal inference
  • ✅ Gated Fusion Layer: Adaptive mixing of graph and textual representations
  • ✅ Flexible Graph Construction: Single-node, multi-node, temporal, and global graphs
  • ✅ Production-Ready: Comprehensive error handling, logging, and testing
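To illustrate what kNN-style graph construction means (a generic sketch, not the framework's internal code), here is a NumPy function that builds a PyG-style `edge_index` from pairwise Euclidean distances:

```python
import numpy as np

def knn_edge_index(x: np.ndarray, k: int = 5) -> np.ndarray:
    """Build a (2, n*k) edge_index linking each sample to its k nearest
    neighbours by Euclidean distance, excluding self-loops."""
    d = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    np.fill_diagonal(d, np.inf)            # forbid self-edges
    nbrs = np.argsort(d, axis=1)[:, :k]    # k closest per row
    src = np.repeat(np.arange(len(x)), k)
    return np.stack([src, nbrs.reshape(-1)])

x = np.random.default_rng(0).normal(size=(100, 8))
edges = knn_edge_index(x, k=5)
print(edges.shape)  # (2, 500)
```

Note the O(n²) distance matrix: for large datasets a KD-tree (e.g. scikit-learn's `NearestNeighbors`) or the framework's similarity-threshold construction is the practical choice.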

Causal Inference Capabilities

  • ✅ Rigorous Assumption Testing: Automated validation of causal identification conditions
  • ✅ Multiple Identification Methods: Backdoor, IPW, doubly robust, with sensitivity analysis
  • ✅ Heterogeneous Treatment Effects: CATE estimation with 5+ methods (X/T/S-Learners, etc.)
  • ✅ Advanced Loss Functions: CFR, DragonNet, TARNet with representation balancing
  • ✅ Uncertainty Quantification: Bayesian, ensemble, conformal prediction approaches
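To make the meta-learner idea behind CATE estimation concrete, here is a deliberately simple T-Learner sketch using plain OLS (the framework's `CATEManager` wraps richer models; this is an illustration of the principle, not its implementation):

```python
import numpy as np

def t_learner(X, T, Y):
    """T-Learner sketch: fit one linear outcome model per treatment arm,
    predict both potential outcomes for every unit, take the difference."""
    Xb = np.column_stack([np.ones(len(X)), X])            # add intercept
    w1, *_ = np.linalg.lstsq(Xb[T == 1], Y[T == 1], rcond=None)
    w0, *_ = np.linalg.lstsq(Xb[T == 0], Y[T == 0], rcond=None)
    tau = Xb @ w1 - Xb @ w0                               # per-unit CATE
    return tau.mean(), tau

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 3))
T = rng.integers(0, 2, size=2000)                         # randomized treatment
Y = X @ np.array([1.0, -0.5, 0.2]) + 2.0 * T + rng.normal(scale=0.1, size=2000)
ate, tau = t_learner(X, T, Y)
print(f"{ate:.2f}")  # close to the true effect of 2.0
```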

Data Processing & Evaluation

  • ✅ Smart Data Loading: CSV, JSON, synthetic data with automatic graph construction
  • ✅ Comprehensive Evaluation: PEHE, ATE, policy value, calibration metrics
  • ✅ Memory Efficiency: Lazy loading, batch processing for large-scale datasets
  • ✅ Easy Configuration: JSON/YAML configs with validation and experiment tracking

๐Ÿ—๏ธ Architecture

 +-----------+      +-----------+
 |  GNN Emb  |      | BERT Emb  |
 +-----------+      +-----------+
        \               /
         \             /
        +---------------+
        |  Fusion Layer |
        +---------------+
                |
        +---------------+
        |   Fused Rep   |
        +---------------+
                |
             CFRNet
            /      \
      mu_0(x)    mu_1(x)
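The architecture above (a sigmoid gate mixing graph and text embeddings, feeding two CFRNet-style outcome heads) can be sketched in a few lines of PyTorch. This is a simplified stand-in, not the actual CANS implementation, and the layer sizes are illustrative:

```python
import torch
import torch.nn as nn

class GatedFusionCFR(nn.Module):
    """Minimal sketch: gated fusion of two embeddings, then separate
    heads predicting mu_0(x) and mu_1(x)."""
    def __init__(self, graph_dim, text_dim, fusion_dim):
        super().__init__()
        self.proj_g = nn.Linear(graph_dim, fusion_dim)
        self.proj_t = nn.Linear(text_dim, fusion_dim)
        self.gate = nn.Linear(2 * fusion_dim, fusion_dim)
        self.mu0 = nn.Sequential(nn.Linear(fusion_dim, 64), nn.ReLU(), nn.Linear(64, 1))
        self.mu1 = nn.Sequential(nn.Linear(fusion_dim, 64), nn.ReLU(), nn.Linear(64, 1))

    def forward(self, g_emb, t_emb):
        g, t = self.proj_g(g_emb), self.proj_t(t_emb)
        z = torch.sigmoid(self.gate(torch.cat([g, t], dim=-1)))  # fusion gate
        fused = z * g + (1 - z) * t                              # adaptive mix
        return self.mu0(fused), self.mu1(fused)

model = GatedFusionCFR(graph_dim=256, text_dim=768, fusion_dim=256)
m0, m1 = model(torch.randn(4, 256), torch.randn(4, 768))
print(m0.shape, m1.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
```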

🚀 Enhanced Causal Analysis Workflow

Complete Example with Assumption Testing & CATE Estimation

import torch

from cans import (
    CANSConfig, CANS, GCN, CANSRunner,
    create_sample_dataset, get_data_loaders,
    CausalAssumptionTester, CausalLossManager,
    CATEManager, UncertaintyManager,
    advanced_counterfactual_analysis
)

# 1. Configuration with enhanced causal features
config = CANSConfig()
config.model.gnn_type = "GCN"
config.training.loss_type = "cfr"  # Causal loss function
config.data.graph_construction = "global"  # Multi-node graphs

# 2. Test causal assumptions BEFORE modeling
# (X: covariate matrix, T: binary treatment vector, Y: outcome vector)
assumption_tester = CausalAssumptionTester()
results = assumption_tester.comprehensive_test(X, T, Y)
print(f"Causal assumptions valid: {results['causal_identification_valid']}")

# 3. Create datasets with enhanced graph construction
datasets = create_sample_dataset(n_samples=1000, config=config.data)
train_loader, val_loader, test_loader = get_data_loaders(datasets)

# 4. Setup model with causal loss functions
from transformers import AutoModel
gnn = GCN(in_dim=64, hidden_dim=128, output_dim=256)
bert = AutoModel.from_pretrained("distilbert-base-uncased")  # DistilBERT checkpoints need AutoModel, not BertModel
model = CANS(gnn, bert, fusion_dim=256)

loss_manager = CausalLossManager("cfr", alpha=1.0, beta=0.5)

# 5. Train with causal-aware pipeline
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
runner = CANSRunner(model, optimizer, config)
history = runner.fit(train_loader, val_loader)

# 6. Multiple counterfactual identification methods
cf_results = advanced_counterfactual_analysis(
    model, test_loader, 
    methods=['backdoor', 'ipw', 'doubly_robust']
)

# 7. CATE estimation with multiple learners
cate_manager = CATEManager(method="x_learner")
cate_manager.fit(X, T, Y)
individual_effects = cate_manager.estimate_cate(X_test)

# 8. Uncertainty quantification
uncertainty_manager = UncertaintyManager(method="conformal")
uncertainty_manager.setup(model)
intervals = uncertainty_manager.estimate_uncertainty(test_loader)

print(f"ATE: {cf_results['backdoor']['ate']:.3f}")
print(f"Coverage: {intervals['coverage_rate']:.3f}")
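Step 8's conformal option can be illustrated on its own. The following is a generic split conformal prediction sketch on synthetic data (not the CANS API): held-out calibration residuals turn point predictions into distribution-free intervals with roughly (1 - alpha) coverage.

```python
import numpy as np

def split_conformal_interval(cal_pred, cal_true, test_pred, alpha=0.1):
    """Split conformal sketch: widen point predictions by the
    finite-sample-corrected quantile of calibration residuals."""
    resid = np.abs(cal_true - cal_pred)
    n = len(resid)
    q = np.quantile(resid, min(1.0, np.ceil((n + 1) * (1 - alpha)) / n))
    return test_pred - q, test_pred + q

rng = np.random.default_rng(0)
cal_true = rng.normal(size=1000)
cal_pred = cal_true + rng.normal(scale=0.5, size=1000)   # noisy predictions
test_true = rng.normal(size=1000)
test_pred = test_true + rng.normal(scale=0.5, size=1000)

lo, hi = split_conformal_interval(cal_pred, cal_true, test_pred, alpha=0.1)
coverage = np.mean((test_true >= lo) & (test_true <= hi))
print(coverage)  # close to the 0.90 target
```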

🚀 Quick Start

Installation

# Clone the repository
git clone https://github.com/rdmurugan/cans-framework.git
cd cans-framework

# Install dependencies
pip install -r cans/requirements.txt

# Or install from PyPI
pip install cans-framework

Core Dependencies:

  • torch>=2.0.0
  • transformers>=4.38.0
  • torch-geometric>=2.3.0
  • scikit-learn>=1.3.0
  • pandas>=2.0.0

Basic Usage (30 seconds to results)

from cans.config import CANSConfig
from cans.utils.data import create_sample_dataset, get_data_loaders
from cans.models import CANS
from cans.models.gnn_modules import GCN
from cans.pipeline.runner import CANSRunner
from transformers import BertModel
import torch

# 1. Create configuration
config = CANSConfig()
config.training.epochs = 10

# 2. Load data (or create sample data)
datasets = create_sample_dataset(n_samples=1000, n_features=64)
train_loader, val_loader, test_loader = get_data_loaders(datasets, batch_size=32)

# 3. Create model
gnn = GCN(in_dim=64, hidden_dim=128, output_dim=256)
bert = BertModel.from_pretrained("bert-base-uncased")
model = CANS(gnn, bert, fusion_dim=256)

# 4. Train
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
runner = CANSRunner(model, optimizer, config)
history = runner.fit(train_loader, val_loader)

# 5. Evaluate
results = runner.evaluate(test_loader)
print(f"Test MSE: {results['mse']:.4f}")
print(f"Average Treatment Effect: {results['ate']:.4f}")

📊 Usage Examples

Example 1: CSV Data with Real Causal Inference

from cans.utils.data import load_csv_dataset
from cans.config import CANSConfig, DataConfig

# Configure data processing
config = CANSConfig()
config.data.graph_construction = "knn"  # or "similarity" 
config.data.knn_k = 5
config.data.scale_node_features = True

# Load your CSV data
datasets = load_csv_dataset(
    csv_path="your_data.csv",
    text_column="review_text",        # Column with text data
    treatment_column="intervention",   # Binary treatment (0/1)
    outcome_column="conversion_rate",  # Continuous outcome  
    feature_columns=["age", "income", "education"],  # Numerical features
    config=config.data
)

train_dataset, val_dataset, test_dataset = datasets

# Check data quality
stats = train_dataset.get_statistics()
print(f"Treatment proportion: {stats['treatment_proportion']:.3f}")
print(f"Propensity overlap valid: {stats['propensity_overlap_valid']}")
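The propensity-overlap check above can also be reproduced outside the framework. A minimal sketch with scikit-learn (already a core dependency); the threshold defaults here are illustrative assumptions, not CANS's internal values:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def positivity_check(X, T, eps=0.05):
    """Overlap/positivity sketch: estimate propensity scores e(x) = P(T=1|x)
    and flag units whose scores fall outside [eps, 1 - eps]."""
    e = LogisticRegression(max_iter=1000).fit(X, T).predict_proba(X)[:, 1]
    violations = (e < eps) | (e > 1 - eps)
    return {"propensity_min": float(e.min()),
            "propensity_max": float(e.max()),
            "violation_rate": float(violations.mean()),
            "overlap_valid": bool(violations.mean() < 0.01)}

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 4))
T = rng.binomial(1, 1 / (1 + np.exp(-0.5 * X[:, 0])))  # mildly confounded
report = positivity_check(X, T)
print(report["overlap_valid"])  # True for this well-overlapped simulation
```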

Example 2: Advanced Configuration & Experiment Tracking

from cans.config import CANSConfig

# Create detailed configuration
config = CANSConfig()

# Model configuration
config.model.gnn_type = "GCN"
config.model.gnn_hidden_dim = 256
config.model.fusion_dim = 512
config.model.text_model = "distilbert-base-uncased"  # Faster BERT variant

# Training configuration  
config.training.learning_rate = 0.001
config.training.batch_size = 64
config.training.epochs = 50
config.training.early_stopping_patience = 10
config.training.gradient_clip_norm = 1.0
config.training.loss_type = "huber"  # Robust to outliers

# Experiment tracking
config.experiment.experiment_name = "healthcare_causal_analysis"
config.experiment.save_every_n_epochs = 5
config.experiment.log_level = "INFO"

# Save configuration for reproducibility
config.save("experiment_config.json")

# Later: load and use
loaded_config = CANSConfig.load("experiment_config.json")

Example 3: Counterfactual Analysis & Treatment Effects

from cans.utils.causal import simulate_counterfactual
import numpy as np

# After training your model...
runner = CANSRunner(model, optimizer, config)
runner.fit(train_loader, val_loader)

# Comprehensive evaluation
test_metrics = runner.evaluate(test_loader)
print("Performance Metrics:")
for metric, value in test_metrics.items():
    print(f"  {metric}: {value:.4f}")

# Counterfactual analysis
cf_control = simulate_counterfactual(model, test_loader, intervention=0)
cf_treatment = simulate_counterfactual(model, test_loader, intervention=1)

# Calculate causal effects
ate = np.mean(cf_treatment) - np.mean(cf_control)
print(f"\nCausal Analysis:")
print(f"Average Treatment Effect (ATE): {ate:.4f}")
print(f"Expected outcome under control: {np.mean(cf_control):.4f}")
print(f"Expected outcome under treatment: {np.mean(cf_treatment):.4f}")

# Individual treatment effects
individual_effects = np.array(cf_treatment) - np.array(cf_control)
print(f"Treatment effect std: {np.std(individual_effects):.4f}")
print(f"% benefiting from treatment: {(individual_effects > 0).mean()*100:.1f}%")
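For comparison with the plug-in ATE above, here is a minimal inverse-probability-weighting (IPW) sketch on synthetic confounded data. The true propensities are known in this simulation; in practice you would estimate them:

```python
import numpy as np

def ipw_ate(Y, T, e):
    """IPW estimator of the ATE: reweight treated outcomes by 1/e(x)
    and control outcomes by 1/(1 - e(x))."""
    return np.mean(T * Y / e) - np.mean((1 - T) * Y / (1 - e))

rng = np.random.default_rng(0)
X = rng.normal(size=20000)
e = 1 / (1 + np.exp(-X))                  # true propensity, depends on X
T = rng.binomial(1, e)                    # treatment confounded by X
Y = X + 2.0 * T + rng.normal(scale=0.1, size=20000)  # true ATE = 2.0

naive = Y[T == 1].mean() - Y[T == 0].mean()   # biased upward by confounding
corrected = ipw_ate(Y, T, e)
print(f"naive={naive:.2f}  ipw={corrected:.2f}")  # IPW lands near 2.0
```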

Example 4: Custom Data Pipeline

from cans.utils.preprocessing import DataPreprocessor, GraphBuilder
from cans.config import DataConfig
import pandas as pd

# Custom preprocessing pipeline
config = DataConfig()
config.graph_construction = "similarity"
config.similarity_threshold = 0.7
config.scale_node_features = True

preprocessor = DataPreprocessor(config)

# Process your DataFrame  
df = pd.read_csv("social_media_posts.csv")
dataset = preprocessor.process_tabular_data(
    data=df,
    text_column="post_content",
    treatment_column="fact_check_label",
    outcome_column="share_count",
    feature_columns=["user_followers", "post_length", "sentiment_score"],
    text_model="bert-base-uncased",
    max_text_length=256
)

# Split with custom ratios
train_ds, val_ds, test_ds = preprocessor.split_dataset(
    dataset, 
    train_size=0.7, 
    val_size=0.2, 
    test_size=0.1
)

🧪 Testing & Development

# Run all tests
pytest tests/ -v

# Run specific test categories  
pytest tests/test_models.py -v        # Model tests
pytest tests/test_validation.py -v    # Validation tests
pytest tests/test_pipeline.py -v      # Training pipeline tests

# Run with coverage
pytest tests/ --cov=cans --cov-report=html

# Run example scripts
python examples/enhanced_usage_example.py
python examples/enhanced_causal_analysis_example.py

๐Ÿ“ Framework Structure

cans-framework/
├── cans/
│   ├── __init__.py              # Main imports
│   ├── config.py                # ✨ Configuration management
│   ├── exceptions.py            # ✨ Custom exceptions
│   ├── validation.py            # ✨ Data validation utilities
│   ├── models/
│   │   ├── cans.py              # Core CANS model (enhanced)
│   │   └── gnn_modules.py       # GNN implementations
│   ├── pipeline/
│   │   └── runner.py            # ✨ Enhanced training pipeline
│   └── utils/
│       ├── causal.py            # Counterfactual simulation
│       ├── data.py              # ✨ Enhanced data loading
│       ├── preprocessing.py     # ✨ Advanced preprocessing
│       ├── logging.py           # ✨ Structured logging
│       └── checkpointing.py     # ✨ Model checkpointing
├── tests/                       # ✨ Comprehensive test suite
├── examples/                    # Usage examples
└── CLAUDE.md                    # Development guide

✨ = New/Enhanced in v2.0

🎯 Use Cases & Applications

Healthcare & Medical

# Analyze treatment effectiveness with patient records + clinical notes
datasets = load_csv_dataset(
    csv_path="patient_outcomes.csv",
    text_column="clinical_notes",
    treatment_column="medication_type", 
    outcome_column="recovery_score",
    feature_columns=["age", "bmi", "comorbidities"]
)

Marketing & A/B Testing

# Marketing campaign effectiveness with customer profiles + ad content
datasets = load_csv_dataset(
    csv_path="campaign_data.csv", 
    text_column="ad_content",
    treatment_column="campaign_variant",
    outcome_column="conversion_rate",
    feature_columns=["customer_ltv", "demographics", "behavior_score"]
)

Social Media & Content Moderation

# Impact of content moderation on engagement
datasets = load_csv_dataset(
    csv_path="posts_data.csv",
    text_column="post_text", 
    treatment_column="moderation_action",
    outcome_column="engagement_score",
    feature_columns=["user_followers", "post_length", "sentiment"]
)

🔬 Research & Methodology

CANS implements state-of-the-art causal inference techniques:

  • Counterfactual Regression Networks (CFRNet): Learn representations that minimize treatment assignment bias
  • Gated Fusion: Adaptively combine graph-structured and textual information
  • Balanced Representation: Minimize distributional differences between treatment groups
  • Propensity Score Validation: Automatic overlap checking for reliable causal estimates
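The balanced-representation idea can be sketched with a linear-kernel MMD penalty: the squared distance between the mean representations of treated and control units. This is a simplification of the squared-MMD / Wasserstein IPM terms used in CFR, not the framework's exact loss:

```python
import torch

def linear_mmd(phi, t):
    """Linear-kernel MMD sketch: squared distance between the mean
    representations of treated (t == 1) and control (t == 0) units.
    Added to the outcome loss as a balancing penalty in CFR-style training."""
    phi_t, phi_c = phi[t == 1], phi[t == 0]
    return ((phi_t.mean(0) - phi_c.mean(0)) ** 2).sum()

# Identical distributions -> penalty near zero; shifted -> large.
g = torch.Generator().manual_seed(0)
phi = torch.randn(1000, 16, generator=g)
t = (torch.rand(1000, generator=g) < 0.5).long()
balanced = linear_mmd(phi, t)
phi_shift = phi + 2.0 * t.unsqueeze(1)   # shift treated representations
shifted = linear_mmd(phi_shift, t)
print(balanced.item() < shifted.item())  # True: shift inflates the penalty
```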

Key Papers:

  • Shalit et al. "Estimating individual treatment effect: generalization bounds and algorithms" (ICML 2017)
  • Yao et al. "Representation learning for treatment effect estimation from observational data" (NeurIPS 2018)

🚀 Performance & Scalability

  • Memory Efficient: Optimized batch processing and gradient checkpointing
  • GPU Acceleration: Full CUDA support with automatic device selection
  • Parallel Processing: Multi-core data loading and preprocessing
  • Production Ready: Comprehensive error handling and logging

Benchmarks (approximate, hardware-dependent):

  • Small: 1K samples, 32 features → ~30 sec training
  • Medium: 100K samples, 128 features → ~10 min training
  • Large: 1M+ samples → Scales with batch size and hardware

📚 Additional Resources

  • Documentation: See CLAUDE.md for detailed development guide
  • Examples: Check examples/ directory for complete workflows
  • Tests: tests/ contains 100+ unit tests demonstrating usage
  • Issues: Report bugs and feature requests on GitHub

๐Ÿค Contributing

Contributions welcome! Please:

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Add tests for new functionality
  4. Run tests: pytest tests/ -v
  5. Submit a pull request

Areas we'd love help with:

  • Additional GNN architectures (GraphSAGE, Graph Transformers)
  • More evaluation metrics for causal inference
  • Integration with popular ML platforms (MLflow, Weights & Biases)
  • Performance optimizations

👨‍🔬 Authors

Durai Rajamanickam – @duraimuruganr. Reach out at durai@infinidatum.net.

📜 License

MIT License. Free to use, modify, and distribute with attribution.


Ready to get started? Try the 30-second quick start above, or dive into the detailed examples! 🚀
