🚀 Getting Started

trustcv provides cross-validation methods designed for clinical machine learning applications, where naive splitting can leak information across patients, sites, or time.

Installation

pip install trustcv

# Or from source
git clone https://github.com/ki-smile/trustcv.git
cd trustcv
pip install -e .

Basic Usage

from trustcv.splitters.iid import StratifiedKFoldMedical
from trustcv.splitters.grouped import GroupKFoldMedical
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier

# X: feature matrix, y: labels, patient_ids: one group label per sample
# Basic stratified cross-validation
cv = StratifiedKFoldMedical(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(RandomForestClassifier(), X, y, cv=cv)

# Patient-grouped cross-validation
cv_grouped = GroupKFoldMedical(n_splits=5)
scores_grouped = cross_val_score(RandomForestClassifier(), X, y,
                                 groups=patient_ids, cv=cv_grouped)

🎲 I.I.D. Methods (9)

Cross-validation methods for independent and identically distributed data.

HoldOut

Simple train-test split for initial model evaluation.

HoldOut(test_size=0.2, random_state=None, stratify=True)
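
In effect this is a single stratified train/test split. A minimal sketch of the same behavior using scikit-learn's train_test_split (illustrative only, not trustcv's implementation):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data: 10 samples with imbalanced labels (6 negatives, 4 positives)
X = np.arange(20).reshape(10, 2)
y = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 1])

# Stratified 80/20 hold-out, mirroring HoldOut(test_size=0.2, stratify=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)
print(len(X_train), len(X_test))  # 8 2
```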

KFoldMedical

Standard K-Fold cross-validation.

KFoldMedical(n_splits=5, shuffle=False, random_state=None)

StratifiedKFoldMedical

Stratified K-Fold preserving class distribution in each fold.

StratifiedKFoldMedical(n_splits=5, shuffle=False, random_state=None)

RepeatedKFold

Repeated K-Fold for more stable estimates.

RepeatedKFold(n_splits=5, n_repeats=10, random_state=None)
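
Repeating K-Fold with different shuffles multiplies the number of train/test pairs, which reduces the variance of the performance estimate. A sketch using scikit-learn's splitter of the same name (assuming trustcv's version follows the same interface):

```python
import numpy as np
from sklearn.model_selection import RepeatedKFold

X = np.zeros((20, 3))

# Repeated K-Fold runs K-Fold n_repeats times with fresh shuffles,
# yielding n_splits * n_repeats train/test pairs in total.
rkf = RepeatedKFold(n_splits=5, n_repeats=10, random_state=0)
n_total = rkf.get_n_splits(X)
print(n_total)  # 50
```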

LOOCV

Leave-One-Out Cross-Validation for small datasets.

LOOCV()

LPOCV

Leave-p-Out Cross-Validation for exhaustive testing.

LPOCV(p=2)
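
"Exhaustive" here means every possible test set of size p is enumerated, so the number of splits grows combinatorially as C(n, p). An illustrative check with scikit-learn's LeavePOut:

```python
from math import comb

import numpy as np
from sklearn.model_selection import LeavePOut

X = np.zeros((6, 2))

# Leave-p-Out enumerates every size-p test set: C(6, 2) = 15 splits.
lpo = LeavePOut(p=2)
n_splits = lpo.get_n_splits(X)
print(n_splits, comb(6, 2))  # 15 15
```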

BootstrapValidation

Bootstrap validation with .632/.632+ estimators.

BootstrapValidation(n_iterations=100, estimator='.632', random_state=None)
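
The .632 estimator blends the optimistic resubstitution error with the pessimistic out-of-bag (OOB) error, weighting the latter by the expected fraction of distinct samples drawn in a bootstrap (about 63.2%). An illustrative sketch of the arithmetic, not trustcv's internals; the error values are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100
idx = np.arange(n)

# One bootstrap iteration: draw n indices with replacement; on average
# ~63.2% of distinct samples land in the bootstrap training set.
boot = rng.choice(idx, size=n, replace=True)
oob = np.setdiff1d(idx, boot)  # out-of-bag samples, used as the test set

# The .632 estimator combines (hypothetical) resubstitution and OOB errors:
err_resub, err_oob = 0.05, 0.25
err_632 = 0.368 * err_resub + 0.632 * err_oob
print(round(err_632, 4))  # 0.1764
```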

MonteCarloCV

Random sub-sampling cross-validation.

MonteCarloCV(n_iterations=100, test_size=0.2, random_state=None)

NestedCV

Nested cross-validation for unbiased hyperparameter tuning.

NestedCV(outer_cv=None, inner_cv=None)
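
Nested CV tunes hyperparameters in an inner loop and estimates performance in an outer loop, so the tuning never sees the outer test folds. The same pattern, sketched with plain scikit-learn pieces rather than trustcv's wrapper:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, KFold, cross_val_score

X, y = make_classification(n_samples=60, n_features=5, random_state=0)

# Inner loop tunes C; outer loop gives an unbiased performance estimate.
inner = GridSearchCV(LogisticRegression(max_iter=1000),
                     {"C": [0.1, 1.0]}, cv=KFold(n_splits=3))
outer_scores = cross_val_score(inner, X, y, cv=KFold(n_splits=3))
print(len(outer_scores))  # 3
```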

👥 Grouped Methods (8)

Cross-validation methods for grouped medical data (patients, sites, etc.).

GroupKFoldMedical

Patient-level grouping to prevent data leakage.

GroupKFoldMedical(n_splits=5, shuffle=False)
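
The leakage being prevented: samples from one patient must never appear in both train and test. A sketch of that guarantee using scikit-learn's GroupKFold (assuming trustcv's splitter follows the same interface):

```python
import numpy as np
from sklearn.model_selection import GroupKFold

X = np.zeros((12, 2))
y = np.zeros(12)
patient_ids = np.repeat([0, 1, 2, 3], 3)  # 3 samples per patient

# Every sample from a given patient stays on one side of each split.
gkf = GroupKFold(n_splits=4)
leak_free = True
for train_idx, test_idx in gkf.split(X, y, groups=patient_ids):
    if set(patient_ids[train_idx]) & set(patient_ids[test_idx]):
        leak_free = False
print(leak_free)  # True
```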

StratifiedGroupKFold

Combines grouping with stratification for balanced grouped splits.

StratifiedGroupKFold(n_splits=5, shuffle=False, random_state=None)

LeaveOneGroupOut

Leave one group (patient/hospital) out for testing.

LeaveOneGroupOut()
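
With one fold per distinct group, the number of splits equals the number of groups; each hospital (or patient) is held out exactly once. Illustrated with scikit-learn's splitter of the same name:

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

X = np.zeros((9, 2))
y = np.zeros(9)
hospital = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])

# One fold per distinct group: 3 hospitals -> 3 splits.
logo = LeaveOneGroupOut()
n_splits = logo.get_n_splits(groups=hospital)
print(n_splits)  # 3
```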

LeavePGroupsOut

Leave p groups out for testing.

LeavePGroupsOut(n_groups=2)

RepeatedGroupKFold

Repeated grouped K-Fold for stable estimates.

RepeatedGroupKFold(n_splits=5, n_repeats=10, random_state=None)

HierarchicalGroupKFold

Handles nested hierarchical structures (Hospital→Department→Patient).

HierarchicalGroupKFold(n_splits=5)
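
The key idea: splitting on the top level of the hierarchy keeps whole sub-hierarchies together, so lower-level groups (patients) can never straddle a fold boundary. A toy two-level sketch (hospital containing patients) using scikit-learn's GroupKFold, not trustcv's implementation:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Nested structure: each patient belongs to exactly one hospital.
hospital = np.array([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
patient = np.array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
X = np.zeros((12, 2))
y = np.zeros(12)

# Splitting on the top level (hospital) keeps every patient intact.
gkf = GroupKFold(n_splits=3)
intact = True
for tr, te in gkf.split(X, y, groups=hospital):
    if set(patient[tr]) & set(patient[te]):
        intact = False
print(intact)  # True
```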

MultilevelCV

Cross-validation across multiple hierarchy levels.

MultilevelCV(n_splits=5)

NestedGroupedCV

Nested CV maintaining group structure in both loops.

NestedGroupedCV(outer_cv=None, inner_cv=None)

⏰ Temporal Methods (8)

Cross-validation methods for time-series and temporal medical data.

TimeSeriesSplit

Time-aware splitting for temporal validation.

TimeSeriesSplit(n_splits=5, test_size=None, gap=0)
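
"Time-aware" means train indices always precede test indices, and the gap parameter drops samples between them to mimic a real forecasting delay. A check of both properties using scikit-learn's TimeSeriesSplit (which this splitter's signature matches):

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

X = np.arange(20).reshape(-1, 1)

# With gap=2, the last 2 samples before each test fold are excluded
# from training, so train always ends strictly before test begins.
tscv = TimeSeriesSplit(n_splits=3, gap=2)
ordered = True
for tr, te in tscv.split(X):
    if tr.max() + 2 >= te.min():
        ordered = False
print(ordered)  # True
```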

BlockedTimeSeries

Block-based time series splitting.

BlockedTimeSeries(n_splits=5)

RollingWindowCV

Fixed-size rolling window cross-validation.

RollingWindowCV(window_size=100, forecast_horizon=10, gap=0)
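
A fixed-size window slides forward through time: train on the window, predict the next forecast_horizon samples, then advance. A self-contained sketch of the index generation (illustrative; trustcv's sliding step may differ):

```python
import numpy as np

def rolling_window_splits(n_samples, window_size, horizon, gap=0):
    """Yield (train, test) index arrays for a fixed-size rolling window."""
    start = 0
    while start + window_size + gap + horizon <= n_samples:
        train = np.arange(start, start + window_size)
        test = np.arange(start + window_size + gap,
                         start + window_size + gap + horizon)
        yield train, test
        start += horizon  # slide the window forward by one horizon

splits = list(rolling_window_splits(30, window_size=10, horizon=5))
print(len(splits), len(splits[0][0]), len(splits[0][1]))  # 4 10 5
```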

ExpandingWindowCV

Expanding window cross-validation.

ExpandingWindowCV(min_train_size=50, forecast_horizon=10)

PurgedKFoldCV

K-Fold with purged samples to prevent temporal leakage.

PurgedKFoldCV(n_splits=5, purge_gap=1)
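
Purging removes training samples that sit within purge_gap of the test fold on the time axis, so label or feature windows that overlap the test period cannot leak in. A simplified sketch of the idea (not trustcv's implementation):

```python
import numpy as np

def purged_kfold(n_samples, n_splits, purge_gap):
    """K-Fold over a time axis, dropping train samples within
    purge_gap of the test fold to avoid temporal leakage."""
    folds = np.array_split(np.arange(n_samples), n_splits)
    for fold in folds:
        lo, hi = fold.min() - purge_gap, fold.max() + purge_gap
        train = np.array([i for i in range(n_samples) if i < lo or i > hi])
        yield train, fold

splits = list(purged_kfold(20, n_splits=4, purge_gap=1))
train0, test0 = splits[0]
# Test fold covers indices 0-4; index 5 is purged, so train starts at 6.
print(test0.min(), test0.max(), train0.min())  # 0 4 6
```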

CombinatorialPurgedCV

Multiple train/test combinations with purging.

CombinatorialPurgedCV(n_splits=5, n_test_groups=2, purge_gap=0)

PurgedGroupTimeSeriesSplit

Combines temporal, grouping, and purging.

PurgedGroupTimeSeriesSplit(n_splits=5, purge_gap=0)

NestedTemporalCV

Nested CV respecting temporal order in both loops.

NestedTemporalCV(outer_cv=None, inner_cv=None)

🌍 Spatial Methods (4)

Cross-validation methods for geographic and spatial medical data.

SpatialBlockCV

Spatial block cross-validation for geographic data.

SpatialBlockCV(n_splits=5, n_blocks=None)
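
Spatial blocking assigns nearby points to the same block and holds out whole blocks, so spatially autocorrelated neighbors do not end up on both sides of a split. A toy sketch of the block assignment (a 2x2 median grid; trustcv's blocking scheme may differ):

```python
import numpy as np

rng = np.random.default_rng(0)
lat = rng.uniform(55, 60, size=100)  # synthetic coordinates
lon = rng.uniform(10, 15, size=100)

# Assign each point to one of 4 grid blocks split at the medians;
# whole blocks are then held out together.
lat_bin = (lat > np.median(lat)).astype(int)
lon_bin = (lon > np.median(lon)).astype(int)
block = lat_bin * 2 + lon_bin  # block id in {0, 1, 2, 3}

test_block = 0
test_idx = np.where(block == test_block)[0]
train_idx = np.where(block != test_block)[0]
print(len(test_idx) + len(train_idx))  # 100
```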

BufferedSpatialCV

Spatial CV with buffer zones to reduce autocorrelation.

BufferedSpatialCV(n_splits=5, buffer_size=1.0)

SpatiotemporalBlockCV

Combined spatial and temporal blocking.

SpatiotemporalBlockCV(n_spatial_blocks=3, n_temporal_blocks=3)

EnvironmentalHealthCV

Specialized CV for environmental health studies.

EnvironmentalHealthCV(spatial_blocks=4, temporal_strategy='seasonal')

🔍 Validators & Checkers

High-level validators and data quality checkers.

TrustCVValidator

Main entry point for cross-validation with automatic leakage detection.

TrustCVValidator(cv=None, check_leakage=True)

Example:

from trustcv import TrustCVValidator
from trustcv.splitters.grouped import GroupKFoldMedical

validator = TrustCVValidator(cv=GroupKFoldMedical(n_splits=5), check_leakage=True)
results = validator.fit_validate(model, X, y, groups=patient_ids)

DataLeakageChecker

Automatic detection of 6 types of data leakage.

DataLeakageChecker()
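
One of the leakage types being detected is group leakage: the same patient or site appearing on both sides of a split. A toy check of that idea (a hypothetical helper for illustration, not DataLeakageChecker's actual API):

```python
import numpy as np

def has_group_leakage(train_idx, test_idx, groups):
    """Toy check (not trustcv's API): does any group appear in
    both the train and test index sets?"""
    groups = np.asarray(groups)
    return bool(set(groups[train_idx]) & set(groups[test_idx]))

groups = np.array([0, 0, 1, 1, 2, 2])
print(has_group_leakage([0, 1, 2], [3, 4, 5], groups))  # True (group 1 straddles)
print(has_group_leakage([0, 1, 2, 3], [4, 5], groups))  # False
```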

BalanceChecker

Check class and group balance in CV splits.

BalanceChecker()

💡 Complete Examples

Multi-Site Clinical Trial

from trustcv.splitters.grouped import GroupKFoldMedical
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import roc_auc_score
import numpy as np

# Load your data
# X: features, y: outcomes, site_ids: hospital identifiers

# Initialize grouped cross-validation
cv = GroupKFoldMedical(n_splits=8)  # One fold per hospital
model = GradientBoostingClassifier()

scores = []
for train_idx, test_idx in cv.split(X, y, groups=site_ids):
    # Train on 7 hospitals, test on 1 hospital
    model.fit(X[train_idx], y[train_idx])
    y_pred = model.predict_proba(X[test_idx])[:, 1]
    score = roc_auc_score(y[test_idx], y_pred)
    scores.append(score)

print(f"Cross-hospital AUC: {np.mean(scores):.3f} ± {np.std(scores):.3f}")

ICU Time Series Prediction

from trustcv.splitters.temporal import TimeSeriesSplit
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error

# Temporal validation for ICU monitoring
cv = TimeSeriesSplit(n_splits=5, gap=4)  # gap of 4 samples (e.g., 4 hours at hourly sampling)
model = RandomForestRegressor()

for train_idx, test_idx in cv.split(X_timeseries):
    # Always train on the past, predict the future
    model.fit(X_timeseries[train_idx], y_timeseries[train_idx])
    predictions = model.predict(X_timeseries[test_idx])
    mae = mean_absolute_error(y_timeseries[test_idx], predictions)
    print(f"Fold MAE: {mae:.3f}")