Introduction
EHRShot is a benchmark for few-shot evaluation on Electronic Health Records (EHR) data, covering multiple predictive tasks including operational outcomes, lab values, new diagnoses, and medical imaging findings.
This vignette demonstrates how to use the EHRShotDataset class and the BenchmarkEHRShot task to work with EHRShot data in R.
Dataset Structure
The EHRShot dataset consists of several components:
- ehrshot.csv: Main events table with clinical codes
- splits/person_id_map.csv: Train/validation/test split assignments
- benchmark/[task_name]/labeled_patients.csv: Label files for each prediction task
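Before constructing a dataset, it can help to confirm that the files listed above are actually in place. The following is a minimal sketch in base R, not part of the package API: check_ehrshot_layout is a hypothetical helper, and it assumes the three paths above are relative to the dataset root directory.

```r
# Hypothetical helper (not provided by the package): report which of the
# expected EHRShot files exist under `root`.
# Assumption: the paths listed above are relative to `root`.
check_ehrshot_layout <- function(root, tasks = character()) {
  expected <- c(
    "ehrshot.csv",
    file.path("splits", "person_id_map.csv"),
    # One label file per requested task; empty when no tasks are given
    file.path("benchmark", tasks, "labeled_patients.csv")
  )
  data.frame(
    file = expected,
    exists = file.exists(file.path(root, expected)),
    stringsAsFactors = FALSE
  )
}

# Usage: show only the files that are missing
layout <- check_ehrshot_layout("/path/to/ehrshot", tasks = c("guo_los", "lab_anemia"))
print(layout[!layout$exists, ])
```

Returning a data frame rather than stopping on the first missing file makes it easier to see every gap in the download at once.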
Available Tasks
EHRShot supports multiple categories of prediction tasks:
Operational Outcomes (Binary Classification)
- guo_los: Length of stay prediction
- guo_readmission: Hospital readmission prediction
- guo_icu: ICU admission prediction
Lab Values (Multiclass Classification)
- lab_thrombocytopenia: Low platelet count severity
- lab_hyperkalemia: High potassium level severity
- lab_hypoglycemia: Low blood sugar severity
- lab_hyponatremia: Low sodium level severity
- lab_anemia: Anemia severity
New Diagnoses (Binary Classification)
- new_hypertension: New hypertension diagnosis
- new_hyperlipidemia: New hyperlipidemia diagnosis
- new_pancan: New pancreatic cancer diagnosis
- new_celiac: New celiac disease diagnosis
- new_lupus: New lupus diagnosis
- new_acutemi: New acute myocardial infarction diagnosis
Medical Imaging (Multilabel Classification)
- chexpert: Chest X-ray findings
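The task names can be kept in a small lookup so that scripts can check a name and its task type before building a dataset. This is an illustrative sketch only: ehrshot_tasks and ehrshot_task_type are hypothetical names, not package objects, and the grouping simply mirrors the categories in this vignette, with chexpert added as the multilabel task used in Example 3.

```r
# Hypothetical lookup (not part of the package): task names grouped by
# the task type given in this vignette.
ehrshot_tasks <- list(
  binary = c(
    "guo_los", "guo_readmission", "guo_icu",
    "new_hypertension", "new_hyperlipidemia", "new_pancan",
    "new_celiac", "new_lupus", "new_acutemi"
  ),
  multiclass = c(
    "lab_thrombocytopenia", "lab_hyperkalemia",
    "lab_hypoglycemia", "lab_hyponatremia", "lab_anemia"
  ),
  multilabel = c("chexpert")
)

# Return the task type for a task name, or stop if the name is unknown
ehrshot_task_type <- function(task) {
  for (type in names(ehrshot_tasks)) {
    if (task %in% ehrshot_tasks[[type]]) {
      return(type)
    }
  }
  stop(sprintf("Unknown EHRShot task: '%s'", task))
}

# Usage
print(ehrshot_task_type("lab_anemia"))
```

Failing early on an unknown name is cheaper than discovering a typo after the events table has been loaded.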
Basic Usage
Example 1: Binary Classification Task (Operational Outcome)
# Initialize dataset with a binary classification task
## Not run:
dataset <- EHRShotDataset$new(
  root = "/path/to/ehrshot",
  tables = c("ehrshot", "splits", "guo_los"),
  dev = TRUE # Use dev mode for faster prototyping (limits to 1000 patients)
)
## End(Not run)
# Display dataset statistics
dataset$stats()
# Create task
task <- BenchmarkEHRShot$new(task = "guo_los")
# Generate samples
samples <- dataset$set_task(task = task, num_workers = 4)
# View sample structure
print(samples$samples[[1]])
Example 2: Multiclass Classification Task (Lab Values)
# Initialize dataset with a multiclass task
dataset <- EHRShotDataset$new(
  root = "/path/to/ehrshot",
  tables = c("ehrshot", "splits", "lab_thrombocytopenia"),
  dev = TRUE
)
# Create task
task <- BenchmarkEHRShot$new(task = "lab_thrombocytopenia")
# Generate samples
samples <- dataset$set_task(task = task, num_workers = 4)
# Check output schema
print(task$output_schema) # Should be: list(label = "multiclass")
Example 3: Multilabel Classification Task (Medical Imaging)
# Initialize dataset with the CheXpert task
dataset <- EHRShotDataset$new(
  root = "/path/to/ehrshot",
  tables = c("ehrshot", "splits", "chexpert"),
  dev = FALSE # Use full dataset
)
# Create task
task <- BenchmarkEHRShot$new(task = "chexpert")
# Generate samples
samples <- dataset$set_task(task = task, num_workers = 8)
# View a sample: labels will be a vector of positive indices
print(samples$samples[[1]]$label)
Example 4: Filtering by OMOP Tables
You can filter clinical events by specific OMOP table types to focus on particular kinds of clinical data:
# Create task with OMOP table filtering
# Only use conditions and drug exposures as features
task <- BenchmarkEHRShot$new(
  task = "new_hypertension",
  omop_tables = c("condition_occurrence", "drug_exposure")
)
dataset <- EHRShotDataset$new(
  root = "/path/to/ehrshot",
  tables = c("ehrshot", "splits", "new_hypertension")
)
samples <- dataset$set_task(task = task)
Working with Multiple Tasks
You can create datasets with multiple task labels:
# Initialize dataset with multiple tasks
dataset <- EHRShotDataset$new(
  root = "/path/to/ehrshot",
  tables = c(
    "ehrshot", "splits",
    "guo_los", "guo_readmission", "guo_icu",
    "lab_thrombocytopenia", "new_hypertension"
  ),
  dev = TRUE
)
# You can then create separate samples for each task
task1 <- BenchmarkEHRShot$new(task = "guo_los")
samples1 <- dataset$set_task(task = task1)
task2 <- BenchmarkEHRShot$new(task = "new_hypertension")
samples2 <- dataset$set_task(task = task2)
Sample Structure
Each sample generated by the BenchmarkEHRShot task contains:
- feature: A character vector of clinical codes from the ehrshot events
- label: The prediction target (format depends on task type)
  - Binary: 0 or 1
  - Multiclass: Integer class label
  - Multilabel: Vector of positive class indices
- split: Data split assignment ("train", "val", or "test")
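To make the structure concrete, the sketch below builds a few mock samples by hand and summarizes them with base R. The samples are invented for illustration (the codes are placeholders, not real EHRShot codes), and they only assume the three fields described above for a binary task; real samples come from dataset$set_task().

```r
# Mock samples for illustration only: codes and labels are invented and
# follow the feature/label/split structure described above (binary task).
mock_samples <- list(
  list(feature = c("code_a", "code_b"), label = 1L, split = "train"),
  list(feature = c("code_c"), label = 0L, split = "val"),
  list(feature = c("code_a", "code_d"), label = 0L, split = "test"),
  list(feature = c("code_e"), label = 1L, split = "train")
)

# Pull one field out of every sample
splits <- vapply(mock_samples, function(s) s$split, character(1))
labels <- vapply(mock_samples, function(s) s$label, integer(1))

# Number of clinical codes per sample
n_codes <- lengths(lapply(mock_samples, function(s) s$feature))

# Cross-tabulate labels by split to check class balance
print(table(split = splits, label = labels))
```

Using vapply with an explicit return type means a sample with a missing or malformed field raises an error instead of silently producing a list.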
Performance Tips
- Use dev mode (dev = TRUE) during development to work with a subset of 1000 patients
- Enable parallel processing with the num_workers parameter when generating samples
- Use caching by specifying cache_dir to avoid regenerating samples
- Filter by OMOP tables to reduce the feature space when appropriate
# Example with performance optimizations
dataset <- EHRShotDataset$new(
  root = "/path/to/ehrshot",
  tables = c("ehrshot", "splits", "guo_los"),
  dev = FALSE
)
task <- BenchmarkEHRShot$new(
  task = "guo_los",
  omop_tables = c("condition_occurrence", "procedure_occurrence", "drug_exposure")
)
samples <- dataset$set_task(
  task = task,
  num_workers = 8, # Use 8 parallel workers
  cache_dir = "./cache" # Cache samples for reuse
)
Data Splits
The EHRShot dataset includes predefined train/validation/test splits. To work with specific splits:
# After generating samples, filter by split
train_samples <- Filter(function(s) s$split == "train", samples$samples)
val_samples <- Filter(function(s) s$split == "val", samples$samples)
test_samples <- Filter(function(s) s$split == "test", samples$samples)
cat(sprintf(
  "Train: %d, Val: %d, Test: %d\n",
  length(train_samples),
  length(val_samples),
  length(test_samples)
))
Additional Resources
- EHRShot Website: https://som-shahlab.github.io/ehrshot-website/
- RHealth Documentation: [link to your package docs]
- For questions and issues: [link to your issue tracker]
References
For more information about the EHRShot benchmark, please refer to:
- EHRShot: An EHR Benchmark for Few-Shot Evaluation of Foundation Models
- https://som-shahlab.github.io/ehrshot-website/