Use trustcv with any ML framework — the same specialized CV strategies work across scikit-learn, PyTorch, TensorFlow, MONAI, and JAX.
TrustCV is designed from the ground up to be framework-agnostic. Whether you are training a simple random forest with scikit-learn or a 3D U-Net with MONAI, the same cross-validation splitters and runners handle data partitioning, leakage prevention, and result aggregation. This is particularly important for clinical, temporal, and spatial data, where a splitting mistake silently invalidates results.
Key innovation: trustcv provides 29 specialized CV methods, many of which are not available in scikit-learn. All of them work with every supported framework through the UniversalCVRunner interface.
These methods address real-world data structures that standard K-fold cannot handle correctly.
Problem: Standard K-fold causes look-ahead bias in temporal data.
Solution: Adds configurable temporal gaps between train and test sets, preventing information leakage across time boundaries.
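The gap idea can be sketched framework-free: each training window ends a fixed number of samples before its test window begins, so nothing inside the gap informs both sides. This is a hand-rolled illustration of the concept, not trustcv's actual implementation:

```python
def gapped_time_splits(n_samples, n_splits, gap):
    """Yield (train, test) index lists where `gap` samples separate
    the end of each training window from the start of its test window."""
    fold_size = n_samples // (n_splits + 1)
    for k in range(1, n_splits + 1):
        train_end = k * fold_size
        test_start = train_end + gap          # skip the gap region entirely
        test_end = min(test_start + fold_size, n_samples)
        yield list(range(train_end)), list(range(test_start, test_end))

for train, test in gapped_time_splits(n_samples=100, n_splits=3, gap=5):
    assert test[0] - train[-1] > 5            # temporal gap is respected
```

Standard K-fold would place samples from either side of the boundary in both sets; here the gap guarantees a buffer of unused samples between them.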
Problem: Patients nested within hospitals, hospitals within regions.
Solution: Respects multi-level grouping structure so that entire clusters stay together during splitting.
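Why the full hierarchy matters can be shown with composite cluster keys: if folds are assigned per (region, hospital) pair, every patient in a hospital lands in the same fold. A hand-rolled sketch of the idea (not trustcv's algorithm):

```python
from collections import defaultdict

def cluster_folds(records, n_folds=3):
    """Assign whole clusters (here: hospitals) to folds round-robin,
    keyed by the full hierarchy so no cluster straddles two folds."""
    keys = sorted({(r['region'], r['hospital']) for r in records})
    fold_of = {key: i % n_folds for i, key in enumerate(keys)}
    folds = defaultdict(list)
    for r in records:
        folds[fold_of[(r['region'], r['hospital'])]].append(r['patient'])
    return dict(folds)

records = [
    {'region': 'north', 'hospital': 'A', 'patient': 1},
    {'region': 'north', 'hospital': 'A', 'patient': 2},
    {'region': 'north', 'hospital': 'B', 'patient': 3},
    {'region': 'south', 'hospital': 'C', 'patient': 4},
]
folds = cluster_folds(records)
# Patients 1 and 2 share hospital A, so they always share a fold.
```

Splitting on patient ID alone would still allow two patients from the same hospital to end up on opposite sides of a split, leaking site-level effects.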
Problem: Spatial autocorrelation violates the independence assumption.
Solution: Creates spatially separated train/test blocks with automatic block-size determination.
Problem: Nearby samples can still leak information across block borders.
Solution: Adds configurable buffer zones between spatial partitions to prevent spillover effects.
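The buffer idea reduces to this in one dimension: after choosing a test block, drop any training sample within `buffer` units of it, so near-boundary neighbours cannot leak. A simplified 1-D sketch (trustcv generalizes this to spatial coordinates):

```python
def buffered_split(coords, test_lo, test_hi, buffer):
    """Return (train, test) indices for a 1-D test block [test_lo, test_hi],
    excluding training points that fall within `buffer` of the block."""
    test = [i for i, c in enumerate(coords) if test_lo <= c <= test_hi]
    train = [i for i, c in enumerate(coords)
             if c < test_lo - buffer or c > test_hi + buffer]
    return train, test

coords = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
train, test = buffered_split(coords, test_lo=4, test_hi=6, buffer=1)
# test = [4, 5, 6]; coords 3 and 7 sit in the buffer and are used by neither side
```

The points inside the buffer are sacrificed from training entirely, which trades a little data for independence at the block border.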
Problem: Need to validate on multiple future periods simultaneously.
Solution: Combinatorial approach that tests on several future windows while purging temporal overlap.
Select a framework below to see a complete integration example.
The simplest integration — pass any sklearn estimator directly to the runner.
```python
from trustcv import StratifiedKFold, UniversalCVRunner
from sklearn.ensemble import RandomForestClassifier

cv = StratifiedKFold(n_splits=5)
runner = UniversalCVRunner(cv_splitter=cv)

results = runner.run(
    model=RandomForestClassifier(),
    data=(X, y)
)

print(f"Mean score: {results.mean_score:.4f}")
```
Clinical risk prediction with a tabular deep network, patient-grouped stratified CV, and early stopping.
```python
import torch
import torch.nn as nn

from trustcv import TorchCVRunner, StratifiedGroupKFold
from trustcv import EarlyStopping, ProgressLogger

class ClinicalRiskNet(nn.Module):
    def __init__(self, input_dim, hidden_dims=(128, 64, 32), dropout=0.3):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 2))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

cv = StratifiedGroupKFold(n_splits=5)
runner = TorchCVRunner(
    model_fn=lambda: ClinicalRiskNet(n_features),
    cv_splitter=cv
)

results = runner.run(
    dataset=(X, y),
    epochs=50,
    optimizer_fn=lambda m: torch.optim.AdamW(
        m.parameters(), lr=0.001, weight_decay=1e-5
    ),
    loss_fn=nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0])),
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=10, restore_best=True),
        ProgressLogger(log_file='training_log.json')
    ],
    groups=patient_ids
)
```
Medical image classification with transfer learning and blocked time-series CV.
```python
import tensorflow as tf
from tensorflow import keras

from trustcv import KerasCVRunner, BlockedTimeSeries

def create_transfer_model():
    base_model = keras.applications.ResNet50(
        weights='imagenet', include_top=False,
        input_shape=(224, 224, 3)
    )
    base_model.trainable = False
    return keras.Sequential([
        base_model,
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(4, activation='softmax')
    ])

cv = BlockedTimeSeries(n_splits=5, block_size=200)
runner = KerasCVRunner(
    model_fn=create_transfer_model,
    cv_splitter=cv,
    compile_kwargs={
        'optimizer': keras.optimizers.Adam(1e-4),
        'loss': 'sparse_categorical_crossentropy',
        'metrics': ['accuracy', keras.metrics.AUC(name='auc')]
    }
)

results = runner.run(
    X=X, y=y,
    epochs=30, batch_size=16,
    callbacks=[
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5),
        keras.callbacks.EarlyStopping(patience=5)
    ],
    groups=timestamps
)

print(f"Mean AUC: {results.mean_score['val_auc']:.4f}")
```
Brain tumor segmentation with 3D U-Net, patient-grouped CV, and MONAI transforms.
```python
import torch

from trustcv import MONAICVRunner, GroupKFoldMedical
from trustcv import EarlyStopping, ModelCheckpoint
from trustcv.frameworks.monai import MONAIAdapter

from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

# Patient-grouped CV (no patient in both train & test)
cv_splitter = GroupKFoldMedical(n_splits=5)

def create_unet():
    return UNet(
        spatial_dims=3, in_channels=1, out_channels=2,
        channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
        num_res_units=2
    )

runner = MONAICVRunner(
    model_fn=create_unet,
    cv_splitter=cv_splitter,
    adapter=MONAIAdapter(
        batch_size=2, cache_rate=0.5,
        roi_size=(96, 96, 96), device='cuda'
    )
)

results = runner.run(
    data_dicts=data_dicts,
    epochs=100,
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    optimizer_fn=lambda m: torch.optim.AdamW(m.parameters(), lr=1e-4),
    loss_fn=DiceLoss(to_onehot_y=True, softmax=True),
    metrics=[DiceMetric(include_background=False, reduction="mean")],
    callbacks=[
        EarlyStopping(monitor='val_dice', mode='max', patience=10),
        ModelCheckpoint(filepath='best_model_fold_{fold}.pth', monitor='val_dice')
    ],
    groups=patient_ids
)

print(f"Mean Dice: {results.mean_score['val_dice']:.4f}")
```
Flax MLP with patient-grouped stratified CV and JIT compilation.
```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

from trustcv.frameworks.jax import JAXAdapter, JAXCVRunner
from trustcv.splitters import StratifiedGroupKFold

class MedicalMLP(nn.Module):
    hidden_dim: int = 64
    n_classes: int = 2

    @nn.compact
    def __call__(self, x, training: bool = True):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.1, deterministic=not training)(x)
        x = nn.Dense(self.hidden_dim // 2)(x)
        x = nn.relu(x)
        x = nn.Dense(self.n_classes)(x)
        return x

runner = JAXCVRunner(
    model_fn=lambda: MedicalMLP(hidden_dim=64, n_classes=2),
    cv_splitter=StratifiedGroupKFold(n_splits=5),
    adapter=JAXAdapter(batch_size=32, seed=42, use_jit=True)
)

results = runner.run(
    X, y,
    epochs=20,
    groups=patient_ids,
    optimizer=optax.adam(1e-3)
)

print(results.summary())
```
Install only the extras you need, or grab everything at once.
| Framework | Install Command | Notes |
|---|---|---|
| Core (scikit-learn) | `pip install trustcv` | Includes all 29 CV splitters and sklearn runners |
| PyTorch | `pip install trustcv[torch]` | Adds TorchCVRunner and DataLoader helpers |
| TensorFlow / Keras | `pip install trustcv[tensorflow]` | Adds KerasCVRunner and tf.data integration |
| MONAI | `pip install trustcv[monai]` | Adds MONAICVRunner, 3D transforms, CacheDataset |
| JAX / Flax | `pip install trustcv[jax]` | Adds JAXCVRunner and JIT-compiled training loops |
| All frameworks | `pip install trustcv[all]` | Everything included |
Use the DataLeakageChecker to verify that no patient, temporal, or spatial overlap exists between training and test folds before trusting your results.
Attach the RegulatoryDocumentationLogger as a callback to produce structured outputs that map to FDA and CE MDR requirements. See the Report Generator for details.
Start with UniversalCVRunner and a basic splitter. Only switch to framework-specific runners (TorchCVRunner, KerasCVRunner, etc.) when you need advanced features such as custom DataLoaders, callbacks, or JIT compilation.
Inspect results.indices and confirm zero overlap between train and test patient IDs. This is especially critical for multi-center clinical trials.
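Such an audit takes only a few lines over `(train, test)` index pairs. The sketch below assumes a list-of-pairs shape; the exact structure of `results.indices` may differ:

```python
def audit_patient_overlap(fold_indices, patient_ids):
    """Return, per fold, the set of patient IDs that appear in
    both the training and the test indices (should be empty)."""
    leaks = {}
    for fold, (train_idx, test_idx) in enumerate(fold_indices):
        train_pts = {patient_ids[i] for i in train_idx}
        test_pts = {patient_ids[i] for i in test_idx}
        leaks[fold] = train_pts & test_pts
    return leaks

patient_ids = ['p1', 'p1', 'p2', 'p2', 'p3', 'p3']
clean = [([0, 1, 2, 3], [4, 5])]          # p3 appears only in test
leaky = [([0, 1, 2], [3, 4, 5])]          # p2 appears on both sides
assert audit_patient_overlap(clean, patient_ids) == {0: set()}
assert audit_patient_overlap(leaky, patient_ids) == {0: {'p2'}}
```

A non-empty set for any fold means at least one patient contributes to both training and evaluation, which inflates the reported score.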