Skip to content

Commit

Permalink
We have restored the parameter values to their previous values.
Browse files Browse the repository at this point in the history
  • Loading branch information
MatsuuraKentaro committed Jan 1, 2025
1 parent f767d04 commit a1c9dfc
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 15 deletions.
4 changes: 2 additions & 2 deletions R/rl_config_set.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#' @param cores A positive integer value. Number of CPU cores used for learning.
#' @param gamma A positive numeric value. Discount factor of the Markov decision
#' process. Default is 1.0 (not discount).
#' @param lr A positive numeric value. Learning rate (default 3e-5). You can set
#' @param lr A positive numeric value. Learning rate (default 5e-5). You can set
#' a learning schedule instead of a learning rate.
#' @param train_batch_size A positive integer value. Training batch size.
#' Deprecated on the new API stack.
Expand Down Expand Up @@ -41,7 +41,7 @@ rl_config_set <- function(iter = 1000L,
save_every_iter = NULL,
cores = 4L,
# Common settings
gamma = 1.0, lr = 3e-5,
gamma = 1.0, lr = 5e-5,
train_batch_size = 10000L, model = rl_dnn_config(),
# PPO specific settings
sgd_minibatch_size = 200L, num_sgd_iter = 20L,
Expand Down
2 changes: 1 addition & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ compute_state <- function(current_dose, J, data_Ns, data_DLTs, N_total) {
stopifnot(length(data_Ns) == length(data_DLTs))
stopifnot(all(data_Ns >= data_DLTs))

is_final <- ifelse(sum(data_Ns) == N_total, 1.0, 0.0)
is_final <- ifelse(sum(data_Ns) == N_total, 0.2, 0.1)

state <- as.array(c(
(current_dose - 1) / J,
Expand Down
8 changes: 5 additions & 3 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ escalation_rule <- learn_escalation_rule(
escalation_rule
#> <EscalationRule>
#> dir: escalation_rules/20241231_130626
#> created at: 2024-12-31 14:23:31
#> dir: escalation_rules/20250101_162633
#> created at: 2025-01-01 17:43:23
#> call:
#> learn_escalation_rule(J = 6, target = 0.25, epsilon = 0.04, delta = 0.1,
#> N_total = 36, N_cohort = 3, seed = 123, rl_config = rl_config_set(iter = 1000))
Expand Down Expand Up @@ -141,12 +141,14 @@ library(dplyr)
MTD_true <- list("MTD_6", c("MTD_3", "MTD_4"), "no_MTD", "MTD_4")
d_sim |>
d_res <- d_sim |>
filter(cohortID == max(cohortID), .by = c(scenarioID, simID)) |>
rowwise() |>
mutate(correct = if_else(recommended %in% MTD_true[[scenarioID]], 1, 0)) |>
ungroup() |>
summarise(PCS = mean(correct), .by = scenarioID)
d_res
#> # A tibble: 4 × 2
#> scenarioID PCS
#> <int> <dbl>
Expand Down
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ escalation_rule <- learn_escalation_rule(

escalation_rule
#> <EscalationRule>
#> dir: escalation_rules/20241231_130626
#> created at: 2024-12-31 14:23:31
#> dir: escalation_rules/20250101_162633
#> created at: 2025-01-01 17:43:23
#> call:
#> learn_escalation_rule(J = 6, target = 0.25, epsilon = 0.04, delta = 0.1,
#> N_total = 36, N_cohort = 3, seed = 123, rl_config = rl_config_set(iter = 1000))
Expand Down Expand Up @@ -143,12 +143,14 @@ library(dplyr)

MTD_true <- list("MTD_6", c("MTD_3", "MTD_4"), "no_MTD", "MTD_4")

d_sim |>
d_res <- d_sim |>
filter(cohortID == max(cohortID), .by = c(scenarioID, simID)) |>
rowwise() |>
mutate(correct = if_else(recommended %in% MTD_true[[scenarioID]], 1, 0)) |>
ungroup() |>
summarise(PCS = mean(correct), .by = scenarioID)

d_res
#> # A tibble: 4 × 2
#> scenarioID PCS
#> <int> <dbl>
Expand Down
2 changes: 1 addition & 1 deletion inst/python/DoseEscalationEnv.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def _compute_state(self):
Returns:
The specific value of the state s.
"""
is_final = 1.0 if np.sum(self.Ns) == self.N_total else 0.0
is_final = 0.2 if np.sum(self.Ns) == self.N_total else 0.1

return np.concatenate((
np.array([self.current_dose / (self.J - 1)]),
Expand Down
4 changes: 2 additions & 2 deletions man/rl_config_set.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions vignettes/RLescalation.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ escalation_rule <- learn_escalation_rule(
escalation_rule
#> <EscalationRule>
#> dir: escalation_rules/20241231_130626
#> created at: 2024-12-31 14:23:31
#> dir: escalation_rules/20250101_162633
#> created at: 2025-01-01 17:43:23
#> call:
#> learn_escalation_rule(J = 6, target = 0.25, epsilon = 0.04, delta = 0.1,
#> N_total = 36, N_cohort = 3, seed = 123, rl_config = rl_config_set(iter = 1000))
Expand Down Expand Up @@ -145,12 +145,14 @@ library(dplyr)
MTD_true <- list("MTD_6", c("MTD_3", "MTD_4"), "no_MTD", "MTD_4")
d_sim |>
d_res <- d_sim |>
filter(cohortID == max(cohortID), .by = c(scenarioID, simID)) |>
rowwise() |>
mutate(correct = if_else(recommended %in% MTD_true[[scenarioID]], 1, 0)) |>
ungroup() |>
summarise(PCS = mean(correct), .by = scenarioID)
d_res
#> # A tibble: 4 × 2
#> scenarioID PCS
#> <int> <dbl>
Expand Down

0 comments on commit a1c9dfc

Please sign in to comment.