Skip to content

Commit

Permalink
Store the RIRA model names in the seurat object misc slot (#123)
Browse files Browse the repository at this point in the history
* Store the RIRA model names in the seurat object misc slot

* Add test
  • Loading branch information
bbimber authored Dec 19, 2024
1 parent b0564fa commit 4de8cb2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
15 changes: 12 additions & 3 deletions R/CellTypist.R
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,9 @@ TrainCellTypist <- function(seuratObj, labelField, modelFile, minCellsPerClass =
#'
#' @export
Classify_TNK <- function(seuratObj, assayName = Seurat::DefaultAssay(seuratObj), columnPrefix = 'RIRA_TNK_v2.', maxAllowableClasses = 6, minFractionToInclude = 0.01, minCellsToRun = 200, maxBatchSize = 600000, retainProbabilityMatrix = FALSE) {
modelName <- "RIRA_TNK_v2"
seuratObj <- RunCellTypist(seuratObj = seuratObj,
modelName = "RIRA_TNK_v2",
modelName = modelName,
# These are optimized for this model:
pThreshold = 0.5, minProp = 0, useMajorityVoting = FALSE, mode = "prob_match",

Expand All @@ -455,6 +456,8 @@ Classify_TNK <- function(seuratObj, assayName = Seurat::DefaultAssay(seuratObj),
seuratObj$RIRA_TNK_v2.cellclass[seuratObj$RIRA_Immune_v2.majority_voting != 'T_NK'] <- 'Other'
}

seuratObj@misc$RIRA_TNK_Model <- modelName

return(seuratObj)
}

Expand All @@ -473,8 +476,9 @@ Classify_TNK <- function(seuratObj, assayName = Seurat::DefaultAssay(seuratObj),
#'
#' @export
Classify_Myeloid <- function(seuratObj, assayName = Seurat::DefaultAssay(seuratObj), columnPrefix = 'RIRA_Myeloid_v3.', maxAllowableClasses = 6, minFractionToInclude = 0.01, minCellsToRun = 200, maxBatchSize = 600000, retainProbabilityMatrix = FALSE) {
modelName <- "RIRA_FineScope_Myeloid_v3"
seuratObj <- RunCellTypist(seuratObj = seuratObj,
modelName = "RIRA_FineScope_Myeloid_v3",
modelName = modelName,
# These are optimized for this model:
pThreshold = 0.5, minProp = 0, useMajorityVoting = FALSE, mode = "prob_match",

Expand All @@ -498,6 +502,8 @@ Classify_Myeloid <- function(seuratObj, assayName = Seurat::DefaultAssay(seuratO
vect[seuratObj@meta.data[[fn2]] %in% c('DC', 'Mature DC')] <- 'DC'
seuratObj[[fn]] <- as.factor(vect)

seuratObj@misc$RIRA_Myeloid_Model <- modelName

return(seuratObj)
}

Expand All @@ -517,6 +523,7 @@ Classify_Myeloid <- function(seuratObj, assayName = Seurat::DefaultAssay(seuratO
#'
#' @export
Classify_ImmuneCells <- function(seuratObj, assayName = Seurat::DefaultAssay(seuratObj), columnPrefix = 'RIRA_Immune_v2.', maxAllowableClasses = 6, minFractionToInclude = 0.01, minCellsToRun = 200, maxBatchSize = 600000, retainProbabilityMatrix = FALSE, filterDisallowedClasses = TRUE) {
modelName <- 'RIRA_Immune_v2'
if ('RIRA_Immune_v1.cellclass' %in% names(seuratObj@meta.data)) {
print('Dropping legacy RIRA_Immune_v1 columns')
toDrop <- grep(names(seuratObj@meta.data), pattern = 'RIRA_Immune_v1', value = TRUE)
Expand All @@ -526,7 +533,7 @@ Classify_ImmuneCells <- function(seuratObj, assayName = Seurat::DefaultAssay(seu
}

seuratObj <- RunCellTypist(seuratObj = seuratObj,
modelName = 'RIRA_Immune_v2',
modelName = modelName,

# These are optimized for this model:
minProp = 0.5, useMajorityVoting = TRUE, mode = "prob_match",
Expand Down Expand Up @@ -569,6 +576,8 @@ Classify_ImmuneCells <- function(seuratObj, assayName = Seurat::DefaultAssay(seu
)
)

seuratObj@misc$RIRA_Immune_Model <- modelName

return(seuratObj)
}

Expand Down
3 changes: 3 additions & 0 deletions tests/testthat/test-celltypist.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ test_that("celltypist runs for RIRA models", {
seuratObj <- Classify_TNK(seuratObj, retainProbabilityMatrix = TRUE)
print(table(seuratObj$RIRA_TNK_v2.cellclass))

expect_equal('RIRA_TNK_v2', seuratObj@misc$RIRA_TNK_Model)
expect_equal(4, length(unique(seuratObj$RIRA_TNK_v2.cellclass)), info = 'using RIRA T_NK', tolerance = 1)
expect_equal(221, unname(table(seuratObj$RIRA_TNK_v2.cellclass)['CD4+ T Cells']), tolerance = 1)
expect_equal(1028, unname(table(seuratObj$RIRA_TNK_v2.cellclass)['CD8+ T Cells']), tolerance = 1)
Expand All @@ -93,6 +94,7 @@ test_that("celltypist runs for RIRA models", {
print(table(seuratObj$RIRA_Myeloid_v3.cellclass))
print(table(seuratObj$RIRA_Myeloid_v3.coarseclass))

expect_equal('RIRA_FineScope_Myeloid_v3', seuratObj@misc$RIRA_Myeloid_Model)
expect_equal(5, length(unique(seuratObj$RIRA_Myeloid_v3.cellclass)), info = 'using RIRA Myeloid')
expect_equal(32, unname(table(seuratObj$RIRA_Myeloid_v3.cellclass)['DC']), tolerance = 1)
expect_equal(32, unname(table(seuratObj$RIRA_Myeloid_v3.coarseclass)['DC']), tolerance = 1)
Expand All @@ -108,6 +110,7 @@ test_that("FilterDisallowedClasses works as expected", {
print('RIRA_Immune_v2.cellclass:')
print(table(seuratObj$RIRA_Immune_v2.cellclass))

expect_equal('RIRA_Immune_v2', seuratObj@misc$RIRA_Immune_Model)
expect_equal(256, sum(seuratObj$RIRA_Immune_v2.cellclass == 'Bcell', na.rm = T), tolerance = 1)
expect_equal(571, sum(seuratObj$RIRA_Immune_v2.cellclass == 'Myeloid', na.rm = T))
expect_equal(1303, sum(seuratObj$RIRA_Immune_v2.cellclass == 'T_NK', na.rm = T))
Expand Down

0 comments on commit 4de8cb2

Please sign in to comment.