Skip to content

Instantly share code, notes, and snippets.

@topepo
Last active June 14, 2025 16:55
Show Gist options
  • Save topepo/b1c428316e16ae77e9e33ba8a36b4ea4 to your computer and use it in GitHub Desktop.
Save topepo/b1c428316e16ae77e9e33ba8a36b4ea4 to your computer and use it in GitHub Desktop.
Simple R Shiny app demonstrating stochastic gradient descent (SGD)
# UI: a sidebar holding two collapsible groups of sliders (SGD settings and
# starting values) next to the plot of the optimisation path.
ui <- local({
  # Every slider fills the sidebar width; only id, label, and range differ.
  wide_slider <- function(id, label, min, max, value, step) {
    sliderInput(
      inputId = id,
      label = label,
      min = min,
      max = max,
      value = value,
      step = step,
      width = "100%"
    )
  }

  # Tuning parameters, all entered on a log scale.
  sgd_panel <- accordion_panel(
    "SGD Parameters",
    wide_slider("epochs", "# Log10 Epochs", min = 1, max = 3, value = 2, step = 0.5),
    wide_slider("batches", "Log2 Batch Size", min = 0, max = 7, value = 0, step = 1),
    wide_slider("rate", "Log10 Learn. Rate", min = -3, max = 0, value = -1, step = 0.5)
  )

  # Where the optimiser starts in (intercept, slope) space.
  start_panel <- accordion_panel(
    "Starting Values",
    wide_slider("intercept", "Intercept", min = -2, max = 3.5, value = 0, step = 0.25),
    wide_slider("slope", "Slope", min = -4, max = 1, value = 0, step = 0.25)
  )

  page_sidebar(
    sidebar = sidebar(
      bg = "white",
      accordion(open = NULL, sgd_panel, start_panel)
    ),
    plotOutput("sgd")
  )
})
# Server logic: simulate a two-parameter logistic regression data set, then
# redraw the path taken by stochastic gradient descent over the negative
# log-likelihood surface whenever one of the sliders changes.
server <- function(input, output) {
  true_means <- c(alpha = 1.0, beta = -1.5)
  num_samples <- 100

  # ----------------------------------------------------------------------------
  # Generate data

  set.seed(1)
  x <- runif(num_samples, min = -1)

  set.seed(2)
  tr_data <- tibble::tibble(
    x = x,
    lp = true_means[1] + true_means[2] * x,
    truth = plogis(lp),
    random = runif(num_samples),
    class = factor(ifelse(random <= truth, "A", "B"))
  )
  # Class "A" is coded as the event (1).
  y_bin <- ifelse(tr_data$class == "A", 1, 0)

  # ----------------------------------------------------------------------------
  # Functions

  # Mean negative log-likelihood of the logistic model and its gradient.
  #
  # param: numeric vector c(intercept, slope).
  # x, y:  predictor values and 0/1 outcomes of equal length.
  # Returns a list with `obj` (scalar) and `gradient` (named vector with
  # elements "alpha" and "beta"), both averaged over the supplied rows.
  logistic_values <- function(param, x, y) {
    alpha <- param[1]
    beta <- param[2]
    .expr4 <- exp(-(alpha + x * beta))
    .expr5 <- 1 + .expr4
    .expr6 <- 1 / .expr5 # predicted probability of the event
    .expr9 <- 1 - y
    .expr10 <- 1 - .expr6
    .expr15 <- .expr5^2
    .expr16 <- .expr4 / .expr15
    .expr24 <- .expr4 * x / .expr15
    .value <- -(y * log(.expr6) + .expr9 * log(.expr10))
    .grad <- array(0, c(length(.value), 2L), list(NULL, c("alpha", "beta")))
    .grad[, "alpha"] <- -(y * (.expr16 / .expr6) - .expr9 * (.expr16 / .expr10))
    .grad[, "beta"] <- -(y * (.expr24 / .expr6) - .expr9 * (.expr24 / .expr10))
    list(obj = mean(.value), gradient = colMeans(.grad))
  }

  # Objective only.
  obj_fun <- function(param, x, y) {
    logistic_values(param, x, y)$obj
  }

  # Gradient only.
  grad_fun <- function(param, x, y) {
    logistic_values(param, x, y)$gradient
  }

  # Randomly assign each element of `x` to a batch of roughly `samples` rows.
  # Returns an integer vector of batch labels, one per element of `x`.
  make_batches <- function(x, samples = 1) {
    n <- length(x)
    # Always form at least one batch, even if `samples` exceeds `n`; the
    # original `1:0` would otherwise produce the labels 1 and 0.
    batches <- max(1, floor(n / samples))
    sample(rep_len(seq_len(batches), n))
  }

  # ----------------------------------------------------------------------------
  # Background grid: objective evaluated over the slider ranges

  grid_size <- 20
  grid <-
    crossing(
      alpha = seq(-2, 3.5, length.out = grid_size),
      beta = seq(-4, 1, length.out = grid_size)
    ) |>
    mutate(
      nllh = purrr::map2_dbl(
        alpha,
        beta,
        \(a, b) obj_fun(c(a, b), x = x, y = y_bin)
      )
    )

  # ----------------------------------------------------------------------------

  output$sgd <-
    renderPlot(
      {
        # Sliders are on log scales; convert back to raw units.
        epochs <- floor(10^input$epochs)
        rate <- rate_i <- 10^input$rate
        # Per-epoch divisor for the learning rate; 1 means no decay.
        rate_dilution <- 1
        batch_size <- min(2^input$batches, num_samples)

        beta <- c(alpha = input$intercept, beta = input$slope)

        num_batches <-
          make_batches(y_bin, batch_size) |>
          vctrs::vec_unique_count()

        llh_start <-
          as_tibble_row(beta) |>
          mutate(ind = 0)

        # Collect the path in plain vectors; writing into tibble columns one
        # element at a time inside the loop is needlessly slow.
        num_steps <- epochs * num_batches
        alpha_path <- rep(NA_real_, num_steps)
        beta_path <- rep(NA_real_, num_steps)

        cnt <- 0
        for (.ep in seq_len(epochs)) {
          btch <- make_batches(y_bin, batch_size)
          num_btch <- vctrs::vec_unique_count(btch)
          for (.row in seq_len(num_btch)) {
            cnt <- cnt + 1
            btch_ind <- which(btch == .row)
            grd <- grad_fun(beta, x[btch_ind], y_bin[btch_ind])
            beta <- beta - (rate_i * grd)
            # Record the values *after* the step. The starting point is
            # already held in `llh_start`, so recording before the step
            # duplicated it and never showed the final update.
            alpha_path[cnt] <- beta[["alpha"]]
            beta_path[cnt] <- beta[["beta"]]
          }
          rate_i <- rate_i / rate_dilution
        }

        llh <-
          crossing(
            epoch = seq_len(epochs),
            iteration = seq_len(num_batches)
          ) |>
          mutate(
            alpha = alpha_path,
            beta = beta_path,
            ind = row_number()
          )

        # Path to draw: start, first step, then the end of every epoch.
        best <-
          bind_rows(
            llh_start,
            llh |> slice_head(n = 1),
            llh |> slice_max(ind, n = 1, by = c(epoch))
          )

        per_batch <- round(mean(table(btch)), 1)
        ttl <- cli::format_inline(
          "{epochs} epochs, {per_batch} sample{?s} per batch, rate: {signif(rate, 2)}"
        )

        p <-
          grid |>
          ggplot(aes(alpha, beta)) +
          geom_contour_filled(
            aes(z = nllh),
            bins = 11,
            alpha = 3 / 4,
            show.legend = FALSE
          ) +
          geom_path(data = best, col = "red") +
          # Dotted white cross-hairs mark the data-generating parameters.
          geom_hline(yintercept = true_means[2], col = "white", lty = 3) +
          geom_vline(xintercept = true_means[1], col = "white", lty = 3) +
          coord_fixed(ratio = 1) +
          theme_minimal() +
          labs(
            x = expression(beta[0]),
            y = expression(beta[1]),
            title = ttl
          )

        print(p)
      },
      res = 90
    )
}
# Bundle the UI definition and server function into a Shiny app object.
app <- shinyApp(ui, server)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment