Last active: June 14, 2025 16:55
-
-
Save topepo/b1c428316e16ae77e9e33ba8a36b4ea4 to your computer and use it in GitHub Desktop.
Simple R shiny app for SGD
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Sidebar layout: SGD tuning sliders and starting values in an accordion on
# the left, the optimisation-path plot ("sgd") in the main area.
ui <- local({
  # Every slider in this app spans the full sidebar width.
  full_width_slider <- function(id, label, lo, hi, value, step) {
    sliderInput(
      inputId = id,
      label = label,
      min = lo,
      max = hi,
      value = value,
      step = step,
      width = "100%"
    )
  }

  # Optimiser settings; values are exponents (log10 / log2 scales).
  sgd_panel <- accordion_panel(
    "SGD Parameters",
    full_width_slider("epochs", "# Log10 Epochs", 1, 3, 2, 0.5),
    full_width_slider("batches", "Log2 Batch Size", 0, 7, 0, 1),
    full_width_slider("rate", "Log10 Learn. Rate", -3, 0, -1, 0.5)
  )

  # Initial (intercept, slope) for the SGD run.
  start_panel <- accordion_panel(
    "Starting Values",
    full_width_slider("intercept", "Intercept", -2, 3.5, 0, 0.25),
    full_width_slider("slope", "Slope", -4, 1, 0, 0.25)
  )

  page_sidebar(
    sidebar = sidebar(
      bg = "white",
      # NOTE(review): open = NULL is bslib's default; per its docs that opens
      # only the first panel -- use FALSE if all should start collapsed.
      accordion(sgd_panel, start_panel, open = NULL)
    ),
    plotOutput("sgd")
  )
})
# Server logic: simulate a fixed logistic-regression data set, then, whenever
# an input changes, re-run mini-batch SGD and draw its parameter path over the
# mean negative log-likelihood surface ----
#
# NOTE(review): renderPlot(), crossing(), mutate(), slice_max(),
# as_tibble_row(), bind_rows(), ggplot() etc. are called unqualified, so this
# assumes shiny, tidyr, dplyr, tibble and ggplot2 are attached before the app
# is built -- no library() calls are visible in this file.
server <- function(input, output) {
  # True intercept (alpha) and slope (beta) used to simulate the outcome.
  true_means <- c(alpha = 1.0, beta = -1.5)
  num_samples <- 100

  # ---------------------------------------------------------------------------
  # Generate data: x ~ U(-1, 1); class "A" with probability plogis(lp).
  # Fixed seeds make the simulated data the same in every session.
  set.seed(1)
  x <- runif(num_samples, min = -1)
  set.seed(2)
  tr_data <- tibble::tibble(
    x = x,
    lp = true_means[1] + true_means[2] * x,
    truth = plogis(lp),
    random = runif(num_samples),
    class = factor(ifelse(random <= truth, "A", "B"))
  )
  # 0/1 outcome, coded 1 for class "A".
  y_bin <- ifelse(tr_data$class == "A", 1, 0)

  # ---------------------------------------------------------------------------
  # Functions

  # Mean negative log-likelihood of a logistic model with
  # param = c(intercept, slope), and its gradient (named alpha, beta).
  # Tail probabilities are taken on the log scale so extreme linear
  # predictors give finite values instead of log(0) = -Inf or Inf/Inf = NaN.
  logistic_values <- function(param, x, y) {
    lp <- param[1] + param[2] * x
    log_p <- plogis(lp, log.p = TRUE)                      # log(p)
    log_q <- plogis(lp, lower.tail = FALSE, log.p = TRUE)  # log(1 - p)
    # Derivative of each observation's loss w.r.t. the linear predictor.
    resid <- plogis(lp) - y
    list(
      obj = mean(-(y * log_p + (1 - y) * log_q)),
      gradient = c(alpha = mean(resid), beta = mean(resid * x))
    )
  }

  obj_fun <- function(param, x, y) {
    logistic_values(param, x, y)$obj
  }

  grad_fun <- function(param, x, y) {
    logistic_values(param, x, y)$gradient
  }

  # Randomly assign the elements of x to floor(length(x) / samples) groups of
  # near-equal size; returns one integer group id per element. At least one
  # group is always used, so samples > length(x) cannot yield a group id 0.
  make_batches <- function(x, samples = 1) {
    n <- length(x)
    num_groups <- max(1, floor(n / samples))
    sample(rep_len(seq_len(num_groups), n))
  }

  # ---------------------------------------------------------------------------
  # Background grid: mean negative log-likelihood of the simulated data over
  # alpha in [-2, 3.5] and beta in [-4, 1].
  grid_size <- 20
  grid <-
    crossing(
      alpha = seq(-2, 3.5, length.out = grid_size),
      beta = seq(-4, 1, length.out = grid_size)
    ) |>
    mutate(
      nllh = purrr::map2_dbl(
        alpha, beta,
        \(a, b) obj_fun(c(a, b), x = x, y = y_bin)
      )
    )

  # ---------------------------------------------------------------------------
  # Re-run SGD from the chosen starting values and plot its path.
  output$sgd <-
    renderPlot({
      # Sliders hold exponents; convert to natural units.
      epochs <- floor(10^input$epochs)
      rate <- 10^input$rate
      batch_size <- min(2^input$batches, length(y_bin))
      # Per-epoch learning-rate divisor; 1 keeps the rate constant.
      rate_dilution <- 1
      rate_i <- rate

      beta <- c(alpha = input$intercept, beta = input$slope)
      llh_start <-
        as_tibble_row(beta) |>
        mutate(ind = 0)

      num_batches <-
        make_batches(y_bin, batch_size) |>
        vctrs::vec_unique_count()
      num_steps <- epochs * num_batches

      # Estimates after every update, preallocated (writing into a tibble
      # column on each of up to ~1e5 iterations is needlessly slow).
      alpha_path <- numeric(num_steps)
      beta_path <- numeric(num_steps)

      cnt <- 0
      for (.ep in seq_len(epochs)) {
        btch <- make_batches(y_bin, batch_size)
        for (.row in seq_len(vctrs::vec_unique_count(btch))) {
          btch_ind <- which(btch == .row)
          grd <- grad_fun(beta, x[btch_ind], y_bin[btch_ind])
          beta <- beta - rate_i * grd
          # Record *after* the update so the last step of each epoch holds
          # that epoch's final estimate, including the very last one.
          cnt <- cnt + 1
          alpha_path[cnt] <- beta[[1]]
          beta_path[cnt] <- beta[[2]]
        }
        rate_i <- rate_i / rate_dilution
      }

      llh <-
        crossing(epoch = seq_len(epochs), iteration = seq_len(num_batches)) |>
        mutate(
          alpha = alpha_path,
          beta = beta_path,
          ind = row_number()
        )

      # Path: start -> after the first update -> end of every epoch.
      best <-
        bind_rows(
          llh_start,
          llh |> slice_head(n = 1),
          llh |> slice_max(ind, n = 1, by = epoch)
        )

      per_batch <- round(mean(table(btch)), 1)
      ttl <- cli::format_inline(
        "{epochs} epochs, {per_batch} sample{?s} per batch, rate: {signif(rate, 2)}")

      grid |>
        ggplot(aes(alpha, beta)) +
        geom_contour_filled(
          aes(z = nllh), bins = 11, alpha = 3 / 4, show.legend = FALSE
        ) +
        geom_path(data = best, col = "red") +
        # Dotted white lines mark the true parameter values.
        geom_hline(yintercept = true_means[2], col = "white", lty = 3) +
        geom_vline(xintercept = true_means[1], col = "white", lty = 3) +
        coord_fixed(ratio = 1) +
        theme_minimal() +
        labs(
          x = expression(beta[0]), y = expression(beta[1]),
          title = ttl
        )
    },
    res = 90)
}
# Bundle the UI definition and server function into a Shiny app object.
app <- shinyApp(server = server, ui = ui)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment