Framework Integration Guide

Use trustcv with any ML framework — the same specialized CV strategies work across scikit-learn, PyTorch, TensorFlow, MONAI, and JAX.


Overview

TrustCV is designed from the ground up to be framework-agnostic. Whether you are training a simple random forest with scikit-learn or a 3D U-Net with MONAI, the same cross-validation splitters and runners handle data partitioning, leakage prevention, and result aggregation. This is particularly important in domains such as clinical and geospatial modeling, where an incorrect split silently inflates reported performance.

Key innovation: trustcv provides 29 specialized CV methods, many of which are not available in scikit-learn. All of them work with every supported framework through the UniversalCVRunner interface.
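All of the splitters plug into a training loop the same way. The sketch below is a toy stand-in, not trustcv's implementation: it only illustrates the scikit-learn-style split protocol that any framework's fit/evaluate step can consume.

```python
import numpy as np

class ToyKFold:
    """Stand-in splitter exposing a scikit-learn-style split() protocol.
    A toy for illustration, not trustcv's implementation."""
    def __init__(self, n_splits=5):
        self.n_splits = n_splits

    def split(self, X):
        indices = np.arange(len(X))
        for test_idx in np.array_split(indices, self.n_splits):
            # Everything outside the test fold becomes training data
            yield np.setdiff1d(indices, test_idx), test_idx

# Any training loop -- sklearn fit(), a PyTorch epoch loop, a JAX step --
# only needs the (train_idx, test_idx) pairs the splitter yields.
X = np.arange(20).reshape(10, 2)
fold_sizes = [len(test) for _, test in ToyKFold(n_splits=5).split(X)]
```

Because the runner only depends on this iteration pattern, swapping frameworks never changes how the data is partitioned.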

Custom CV Methods Not in scikit-learn

These methods address real-world data structures that standard K-fold cannot handle correctly.

Time Series

PurgedKFoldCV

Problem: Standard K-fold causes look-ahead bias in temporal data.
Solution: Adds configurable temporal gaps between train and test sets, preventing information leakage across time boundaries.
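The purging idea can be sketched in a few lines of NumPy. This is an illustrative toy under the stated assumptions (contiguous test blocks, a symmetric gap), not trustcv's implementation; the function name and parameters are hypothetical.

```python
import numpy as np

def purged_kfold(n_samples, n_splits=5, gap=2):
    """Yield (train_idx, test_idx) pairs where `gap` samples on each
    side of the contiguous test block are purged from training."""
    indices = np.arange(n_samples)
    for test_block in np.array_split(indices, n_splits):
        start, stop = test_block[0], test_block[-1] + 1
        # Exclude the test block plus `gap` neighbors on both sides
        train_mask = (indices < start - gap) | (indices >= stop + gap)
        yield indices[train_mask], test_block
```

Without the gap, a model trained on sample t+1 could "see" information about test sample t through autocorrelation; purging removes exactly those neighbors.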

Hierarchical

HierarchicalGroupKFold

Problem: Patients nested within hospitals, hospitals within regions.
Solution: Respects multi-level grouping structure so that entire clusters stay together during splitting.
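A minimal sketch of top-level group splitting, assuming hospital IDs are known per sample (the function name is hypothetical, not trustcv's API):

```python
import numpy as np

def hospital_group_kfold(hospital_ids, n_splits=3):
    """Assign whole hospitals (the top grouping level) to folds, so
    every patient from one hospital lands on the same side of a split."""
    hospitals = np.unique(hospital_ids)
    for test_hospitals in np.array_split(hospitals, n_splits):
        in_test = np.isin(hospital_ids, test_hospitals)
        yield np.flatnonzero(~in_test), np.flatnonzero(in_test)
```

Splitting at the hospital level automatically keeps each nested patient (and each patient's samples) together, which plain K-fold over rows cannot guarantee.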

Spatial

SpatialBlockCV

Problem: Spatial autocorrelation violates the independence assumption.
Solution: Creates spatially separated train/test blocks with automatic block-size determination.
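The blocking idea can be illustrated with a fixed block size (trustcv determines it automatically; this toy, with a hypothetical function name, does not):

```python
import numpy as np

def spatial_block_split(coords, block_size=1.0, n_splits=4):
    """Bin 2-D points into square blocks, then hold out whole blocks so
    train and test samples come from disjoint regions of space."""
    cells = np.floor(coords / block_size).astype(int)
    # Encode each point's 2-D cell coordinates as a single block id
    block_ids = cells[:, 0] * 100_000 + cells[:, 1]
    for test_blocks in np.array_split(np.unique(block_ids), n_splits):
        in_test = np.isin(block_ids, test_blocks)
        yield np.flatnonzero(~in_test), np.flatnonzero(in_test)
```

Holding out whole blocks, rather than random points, means nearby (and therefore correlated) samples cannot straddle the train/test boundary.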

Spatial

BufferedSpatialCV

Problem: Nearby samples can still leak information across block borders.
Solution: Adds configurable buffer zones between spatial partitions to prevent spillover effects.
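The buffering step itself is simple to sketch: given a candidate split, drop any training point closer than the buffer distance to a test point. The function below is an illustrative toy, not trustcv's implementation.

```python
import numpy as np

def apply_buffer(coords, train_idx, test_idx, buffer):
    """Drop training points closer than `buffer` to any test point,
    so information cannot spill across a block border."""
    # Pairwise distances: (n_train, n_test)
    d = np.linalg.norm(
        coords[train_idx][:, None, :] - coords[test_idx][None, :, :], axis=-1
    )
    return train_idx[d.min(axis=1) >= buffer]
```

The cost is a smaller training set per fold; the benefit is that spatial autocorrelation within the buffer radius can no longer inflate test scores.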

Advanced Time Series

CombinatorialPurgedCV

Problem: Need to validate on multiple future periods simultaneously.
Solution: Combinatorial approach that tests on several future windows while purging temporal overlap.
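A toy version of the combinatorial scheme, assuming contiguous blocks and a one-sided-symmetric purge (names and parameters are hypothetical, not trustcv's API):

```python
import numpy as np
from itertools import combinations

def combinatorial_purged_splits(n_samples, n_blocks=6, n_test_blocks=2, gap=1):
    """Every combination of `n_test_blocks` contiguous blocks forms one
    test set; training samples within `gap` of a test block are purged."""
    blocks = np.array_split(np.arange(n_samples), n_blocks)
    for combo in combinations(range(n_blocks), n_test_blocks):
        test = np.concatenate([blocks[b] for b in combo])
        banned = set()
        for b in combo:
            lo, hi = blocks[b][0], blocks[b][-1]
            banned.update(range(max(0, lo - gap), min(n_samples, hi + gap + 1)))
        train = np.array(sorted(set(range(n_samples)) - banned))
        yield train, test
```

With 6 blocks and 2 test blocks per split, this yields C(6, 2) = 15 splits, so each model is evaluated on many distinct future windows instead of a single hold-out path.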

Framework Examples

Select a framework below to see a complete integration example.

The simplest integration — pass any sklearn estimator directly to the runner.

from trustcv import StratifiedKFold, UniversalCVRunner
from sklearn.ensemble import RandomForestClassifier

cv = StratifiedKFold(n_splits=5)
runner = UniversalCVRunner(cv_splitter=cv)
results = runner.run(
    model=RandomForestClassifier(),
    data=(X, y)
)

print(f"Mean score: {results.mean_score:.4f}")

Clinical risk prediction with a tabular deep network, patient-grouped stratified CV, and early stopping.

import torch
import torch.nn as nn
from trustcv import TorchCVRunner, StratifiedGroupKFold
from trustcv import EarlyStopping, ProgressLogger

class ClinicalRiskNet(nn.Module):
    def __init__(self, input_dim, hidden_dims=(128, 64, 32), dropout=0.3):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            ])
            prev_dim = hidden_dim
        layers.append(nn.Linear(prev_dim, 2))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

cv = StratifiedGroupKFold(n_splits=5)

runner = TorchCVRunner(
    model_fn=lambda: ClinicalRiskNet(n_features),
    cv_splitter=cv
)

results = runner.run(
    dataset=(X, y),
    epochs=50,
    optimizer_fn=lambda m: torch.optim.AdamW(
        m.parameters(), lr=0.001, weight_decay=1e-5
    ),
    loss_fn=nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0])),
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=10, restore_best=True),
        ProgressLogger(log_file='training_log.json')
    ],
    groups=patient_ids
)

Medical image classification with transfer learning and blocked time-series CV.

import tensorflow as tf
from tensorflow import keras
from trustcv import KerasCVRunner, BlockedTimeSeries

def create_transfer_model():
    base_model = keras.applications.ResNet50(
        weights='imagenet', include_top=False,
        input_shape=(224, 224, 3)
    )
    base_model.trainable = False
    return keras.Sequential([
        base_model,
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(4, activation='softmax')
    ])

cv = BlockedTimeSeries(n_splits=5, block_size=200)

runner = KerasCVRunner(
    model_fn=create_transfer_model,
    cv_splitter=cv,
    compile_kwargs={
        'optimizer': keras.optimizers.Adam(1e-4),
        'loss': 'sparse_categorical_crossentropy',
        'metrics': ['accuracy', keras.metrics.AUC(name='auc')]
    }
)

results = runner.run(
    X=X, y=y,
    epochs=30, batch_size=16,
    callbacks=[
        keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5),
        keras.callbacks.EarlyStopping(patience=5)
    ],
    groups=timestamps
)

print(f"Mean AUC: {results.mean_score['val_auc']:.4f}")

Brain tumor segmentation with 3D U-Net, patient-grouped CV, and MONAI transforms.

import torch
from trustcv import MONAICVRunner, GroupKFoldMedical
from trustcv import EarlyStopping, ModelCheckpoint
from trustcv.frameworks.monai import MONAIAdapter
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric

# Patient-grouped CV (no patient in both train & test)
cv_splitter = GroupKFoldMedical(n_splits=5)

def create_unet():
    return UNet(
        spatial_dims=3, in_channels=1, out_channels=2,
        channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
        num_res_units=2
    )

runner = MONAICVRunner(
    model_fn=create_unet,
    cv_splitter=cv_splitter,
    adapter=MONAIAdapter(
        batch_size=2, cache_rate=0.5,
        roi_size=(96, 96, 96), device='cuda'
    )
)

results = runner.run(
    data_dicts=data_dicts,
    epochs=100,
    train_transforms=train_transforms,
    val_transforms=val_transforms,
    optimizer_fn=lambda m: torch.optim.AdamW(m.parameters(), lr=1e-4),
    loss_fn=DiceLoss(to_onehot_y=True, softmax=True),
    metrics=[DiceMetric(include_background=False, reduction="mean")],
    callbacks=[
        EarlyStopping(monitor='val_dice', mode='max', patience=10),
        ModelCheckpoint(filepath='best_model_fold_{fold}.pth', monitor='val_dice')
    ],
    groups=patient_ids
)

print(f"Mean Dice: {results.mean_score['val_dice']:.4f}")

Flax MLP with patient-grouped stratified CV and JIT compilation.

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from trustcv.frameworks.jax import JAXAdapter, JAXCVRunner
from trustcv.splitters import StratifiedGroupKFold

class MedicalMLP(nn.Module):
    hidden_dim: int = 64
    n_classes: int = 2

    @nn.compact
    def __call__(self, x, training: bool = True):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.1, deterministic=not training)(x)
        x = nn.Dense(self.hidden_dim // 2)(x)
        x = nn.relu(x)
        x = nn.Dense(self.n_classes)(x)
        return x

runner = JAXCVRunner(
    model_fn=lambda: MedicalMLP(hidden_dim=64, n_classes=2),
    cv_splitter=StratifiedGroupKFold(n_splits=5),
    adapter=JAXAdapter(batch_size=32, seed=42, use_jit=True)
)

results = runner.run(
    X, y,
    epochs=20,
    groups=patient_ids,
    optimizer=optax.adam(1e-3)
)

print(results.summary())

Installation by Framework

Install only the extras you need, or grab everything at once.

Framework             Install Command                    Notes
Core (scikit-learn)   pip install trustcv                Includes all 29 CV splitters and sklearn runners
PyTorch               pip install "trustcv[torch]"       Adds TorchCVRunner and DataLoader helpers
TensorFlow / Keras    pip install "trustcv[tensorflow]"  Adds KerasCVRunner and tf.data integration
MONAI                 pip install "trustcv[monai]"       Adds MONAICVRunner, 3D transforms, CacheDataset
JAX / Flax            pip install "trustcv[jax]"         Adds JAXCVRunner and JIT-compiled training loops
All frameworks        pip install "trustcv[all]"         Everything included

(Extras are quoted so shells like zsh do not expand the square brackets.)

Best Practices