Created
May 7, 2020 19:18
-
-
Save gavinsimpson/727900bcd634fa530d2bd7316aa9d065 to your computer and use it in GitHub Desktop.
Animated spline basis functions
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## Plots and animations showing how basis functions are combined to make
## splines, and how those splines are fitted to data
library('ggplot2') | |
library('tibble') | |
library('tidyr') | |
library('dplyr') | |
library('mgcv') | |
library('mvnfast') | |
library('purrr') | |
library('gganimate') | |
theme_set(theme_minimal()) | |
## Test function used to generate the "true" signal: the sum of two
## polynomial terms in `x`. Vectorised over `x`; returns a numeric vector of
## the same length as `x`.
f <- function(x) {
  first_term <- x^11 * (10 * (1 - x))^6
  second_term <- (10 * (10 * x)^3) * (1 - x)^10
  first_term + second_term
}
## Draw `n` random coefficient vectors of length `k` from a multivariate
## normal whose mean is `mu` in every component and whose covariance matrix
## is diagonal with `sigma` on the diagonal.
##
## n      number of draws
## k      number of coefficients (basis functions)
## mu     common mean of the coefficients
## sigma  common diagonal entry of the covariance matrix
##
## Returns whatever mvnfast::rmvn() returns for these arguments (one draw
## per row).
draw_beta <- function(n, k, mu = 1, sigma = 1) {
  ## diag(sigma, nrow = k) is always a k x k matrix. The previous
  ## diag(rep(sigma, k)) collapsed to diag(sigma) when k == 1, which base R
  ## interprets as "identity matrix of size sigma" rather than a 1 x 1
  ## matrix holding sigma.
  rmvn(n = n, mu = rep(mu, k), sigma = diag(sigma, nrow = k))
}
## Weight the columns of a basis matrix by one random draw of coefficients
## and return the result in long format.
##
## bf   basis matrix, one column per basis function
## x    covariate values, one per row of `bf`
## n    number of coefficient draws requested from draw_beta()
## k    number of basis functions (columns of `bf`)
## ...  forwarded to draw_beta()
##
## Returns a tibble with columns `x`, `bf` (basis function label 'f1',
## 'f2', ...) and `y` (weighted basis function value).
weight_basis <- function(bf, x, n = 1, k, ...) {
  coefs <- draw_beta(n = n, k = k, ...)
  ## multiply each column of the basis by its coefficient
  weighted <- sweep(bf, 2L, coefs, '*')
  colnames(weighted) <- paste0('f', seq_along(coefs))
  weighted %>%
    as_tibble() %>%
    add_column(x = x) %>%
    pivot_longer(-x, names_to = 'bf', values_to = 'y')
}
## Generate `draws` independently weighted copies of a basis and stack them.
##
## bf     basis matrix, one column per basis function
## x      covariate values, one per row of `bf`
## draws  number of random coefficient draws
## k      number of basis functions (columns of `bf`)
## ...    forwarded to weight_basis() (and from there to draw_beta())
##
## Returns a tibble of class "random_bases" with a leading integer column
## `draw` identifying the draw, followed by the columns produced by
## weight_basis().
random_bases <- function(bf, x, draws = 10, k, ...) {
  ## purrr::rerun() is deprecated as of purrr 1.0.0; an explicit lapply()
  ## evaluates weight_basis() once per draw, in order, in the same way.
  out <- lapply(seq_len(draws),
                function(i) weight_basis(bf, x = x, k = k, ...))
  out <- bind_rows(out)
  ## each draw contributes length(x) * k rows (one per x value and basis
  ## function)
  out <- add_column(out, draw = rep(seq_len(draws), each = length(x) * k),
                    .before = 1L)
  class(out) <- c("random_bases", class(out))
  out
}
## Plot method for "random_bases" objects: one line per basis function,
## coloured by basis function, with the colour legend suppressed.
##
## x      object of class "random_bases"
## facet  if TRUE, draw one panel per `draw`
## ...    ignored; present for compatibility with the plot() generic
##
## Returns the ggplot object.
plot.random_bases <- function(x, facet = FALSE, ...) {
  plt <- ggplot(x, aes(x = x, y = y, colour = bf)) +
    geom_line(lwd = 1) +
    ## guides(colour = FALSE) is deprecated since ggplot2 3.3.4
    guides(colour = "none")
  if (facet) {
    ## the faceted plot must be assigned back; previously the result of
    ## `plt + facet_wrap()` was discarded, so `facet = TRUE` had no effect
    plt <- plt + facet_wrap(~ draw)
  }
  plt
}
## Rescale `x` linearly so that its minimum maps to 0 and its maximum to 1.
## Returns a vector the same length as `x`.
normalize <- function(x) {
  lo <- min(x)
  hi <- max(x)
  (x - lo) / (hi - lo)
}
## Simulate the example data set.
##   x      uniform random covariate
##   ytrue  f(x)
##   ycent  ytrue centred on its mean
##   yobs   ycent plus Gaussian noise (sd = 0.5)
##   ynorm  yobs rescaled with normalize()
set.seed(1)
N <- 500
data <- tibble(x = runif(N)) %>%
  mutate(ytrue = f(x),
         ycent = ytrue - mean(ytrue),
         yobs = ycent + rnorm(N, sd = 0.5),
         ynorm = normalize(yobs))
## Build a cubic regression spline basis with k evenly spaced knots over the
## range of x, and plot the (unweighted) basis functions.
k <- 10
knots <- with(data, list(x = seq(min(x), max(x), length.out = k)))
sm <- smoothCon(s(x, k = k, bs = "cr"), data = data, knots = knots)[[1]]$X
colnames(sm) <- levs <- paste0("f", seq_len(k))
## Pivot only the basis-function columns. Selecting them by name keeps every
## column of `data` (including `ynorm`) out of the long `bf`/`value` pair;
## the previous `-(x:yobs)` selection let `ynorm` through, so it was drawn
## as if it were an eleventh basis function.
basis <- pivot_longer(cbind(sm, data), cols = all_of(levs), names_to = 'bf')
basis
ggplot(basis, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  ## guides(colour = FALSE) is deprecated since ggplot2 3.3.4
  guides(colour = "none")
## Draw 20 random sets of coefficients for the basis in `sm`, sum the
## weighted basis functions at each x to get one spline per draw, and
## animate the transitions between draws.
set.seed(2)
bfuns <- random_bases(sm, data$x, draws = 20, k = k)

## spline value = sum of the weighted basis functions at each (draw, x)
smooth <- ungroup(summarise(group_by(bfuns, draw, x), spline = sum(y)))

## the splines on their own
p1 <- ggplot(smooth) +
  geom_line(data = smooth, aes(x = x, y = spline), lwd = 1.5) +
  labs(y = 'f(x)', x = 'x') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
p1

smooth_funs <- animate(
  p1 +
    transition_states(draw, transition_length = 4, state_length = 2) +
    ease_aes('cubic-in-out'),
  nframes = 200, height = 800/1.77777, width = 800, res = 120
)

## the weighted basis functions with the resulting spline drawn on top
p <- plot(bfuns) +
  geom_line(data = smooth, aes(x = x, y = spline),
            inherit.aes = FALSE, lwd = 1.5) +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

animate(
  p +
    transition_states(draw, transition_length = 4, state_length = 2) +
    ease_aes('cubic-in-out'),
  nframes = 200
)
## Base plot: centred true function (line) and noisy observations (points)
data_plt <- ggplot(data, aes(x = x, y = ycent)) +
  geom_line(col = 'goldenrod', lwd = 2) +
  geom_point(aes(y = yobs), alpha = 0.2) +
  theme(text = element_text(size = 16))
data_plt

## Unpenalised least-squares fit of the cubic regression spline basis to the
## centred true function; the coefficients are the basis-function weights.
sm2 <- smoothCon(s(x, k = k, bs = "cr"), data = data, knots = knots)[[1]]$X
beta <- coef(lm(ycent ~ sm2 - 1, data = data))
wtbasis <- sweep(sm2, 2L, beta, FUN = "*")
colnames(sm2) <- paste0("F", seq_len(k))
colnames(wtbasis) <- colnames(sm2)

## Long table of the stacked unweighted and weighted basis, with `x` and
## `type` as leading identifier columns
basis <- pivot_longer(
  add_column(as_tibble(rbind(sm2, wtbasis)),
             x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm2)),
             .before = 1L),
  -(x:type), names_to = 'bf'
)

## Same stacked basis, additionally carrying `fitted`: the row sum of the
## basis functions, i.e. the fitted spline for the weighted rows
wtbasis <- as_tibble(rbind(sm2, wtbasis))
wtbasis <- add_column(wtbasis,
                      x = rep(data$x, times = 2),
                      fitted = rowSums(wtbasis),
                      type = rep(c('unweighted', 'weighted'),
                                 each = nrow(sm2)))
wtbasis <- pivot_longer(wtbasis, -(x:type), names_to = 'bf')
## Static plots of the cubic regression spline basis.
##
## `wtbasis` and `basis` each hold two rows per x and basis function (one
## 'unweighted', one 'weighted'). Passing them to geom_line() unfiltered
## puts both rows in the same colour group, so every line zig-zags between
## the two values. Each static plot therefore uses a single `type`.
wt_only <- dplyr::filter(wtbasis, type == 'weighted')
unwt_only <- dplyr::filter(basis, type == 'unweighted')

## weighted basis functions
ggplot(wt_only, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  ## guides(colour = FALSE) is deprecated since ggplot2 3.3.4
  guides(colour = "none") +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

## unweighted basis functions
ggplot(unwt_only, aes(x = x, y = value, colour = bf)) +
  geom_line(lwd = 1, alpha = 0.4) +
  guides(colour = "none") +
  theme(text = element_text(size = 16)) +
  labs(x = 'x', y = expression(f(x)))

## weighted basis functions over the data
data_plt + geom_line(data = wt_only,
                     mapping = aes(x = x, y = value, colour = bf),
                     lwd = 1, alpha = 0.5) +
  guides(colour = "none") +
  theme(text = element_text(size = 16))

## fitted spline and weighted basis functions over the data
data_plt + geom_line(data = wt_only,
                     mapping = aes(x = x, y = fitted),
                     lwd = 1.5, colour = 'steelblue2', alpha = 0.75) +
  geom_line(data = wt_only,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.7) +
  guides(colour = "none") +
  labs(y = expression(f(x)), x = 'x') +
  theme(text = element_text(size = 16))
## Animation for the cubic regression spline: observations as points, basis
## functions coloured by `bf`, and the summed spline in black. The plot is
## animated over `type`, moving between the unweighted and weighted states.
## Printed statically, `p3` overlays both states in the same line groups.
p3 <- ggplot(data, aes(x = x, y = ycent)) +
  geom_point(aes(y = yobs), alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = wtbasis,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  ## guides(colour = FALSE) is deprecated since ggplot2 3.3.4
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Cubic regression spline',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
p3
animate(p3 + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)
## Penalised thin plate regression spline fitted by least squares on an
## augmented system: the model matrix is stacked on top of a scaled square
## root of the penalty matrix, and the response is padded with zeros.
##
## `absorb.cons` is spelled out in full; `absorb = TRUE` only worked through
## partial argument matching.
sm_tprs <- smoothCon(s(x, k = k, bs = "tp"), absorb.cons = TRUE,
                     data = data)[[1]]
E <- t(mroot(sm_tprs$S[[1]]))          # square root penalty
## Augmented model matrix with an explicit intercept column that is 1 for
## the data rows and 0 for the penalty rows. Letting lm() add its own
## intercept would put a 1 in the penalty rows as well, so the intercept
## would leak into the penalty term.
sm_tprsX <- rbind(cbind(1, sm_tprs$X),
                  cbind(0, 0.1 * E))
y <- c(data$yobs, rep(0, nrow(E)))     # augmented data
## no `data` argument: `y` and `sm_tprsX` have nrow(data) + nrow(E) rows and
## are not columns of `data`
beta <- coef(lm(y ~ sm_tprsX - 1))     # beta[1] is the intercept
spline <- sweep(sm_tprs$X, 2L, beta[-1], FUN = "*")
sm_tprs <- sm_tprs$X
## label by the actual number of basis columns rather than assuming k - 1
colnames(spline) <- colnames(sm_tprs) <- paste0("F", seq_len(ncol(sm_tprs)))
## create stacked unweighted and weighted basis
basis <- as_tibble(rbind(sm_tprs, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_tprs)),
             .before = 1L)
## same stacked basis plus `fitted`: row sum of the basis functions plus
## the intercept
spline <- as_tibble(rbind(sm_tprs, spline)) %>%
  add_column(x = rep(data$x, times = 2),
             fitted = rowSums(.) + beta[1L],
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_tprs))) %>%
  pivot_longer(-(x:type), names_to = 'bf')
basis <- pivot_longer(basis, -(x:type), names_to = 'bf')
## Animation for the penalised thin plate regression spline: observations
## as points, basis functions coloured by `bf`, and the fitted spline in
## black, animated over `type` (unweighted vs weighted basis).
ptprs <- ggplot(data, aes(x = x, y = yobs)) +
  geom_point(alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = spline,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  ## guides(colour = FALSE) is deprecated since ggplot2 3.3.4
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Penalised thin plate regression spline',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
ptprs
animate(ptprs + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)
## Gaussian process smooth basis with m = c(3, 0.25), fitted to the noisy
## observations by ordinary (unpenalised) least squares with no intercept.
sm_gp <- smoothCon(s(x, bs = 'gp', k = k, m = c(3, 0.25)), data = data)[[1]]$X
beta <- coef(lm(yobs ~ sm_gp - 1, data = data))
spline <- sweep(sm_gp, 2L, beta, FUN = "*")
colnames(sm_gp) <- paste0("F", seq_len(k))
colnames(spline) <- colnames(sm_gp)

## Long table of the stacked unweighted and weighted basis, with `x` and
## `type` as leading identifier columns
basis <- pivot_longer(
  add_column(as_tibble(rbind(sm_gp, spline)),
             x = rep(data$x, times = 2),
             type = rep(c('unweighted', 'weighted'), each = nrow(sm_gp)),
             .before = 1L),
  -(x:type), names_to = 'bf'
)

## Same stacked basis, additionally carrying `fitted`: the row sum of the
## basis functions
spline <- as_tibble(rbind(sm_gp, spline))
spline <- add_column(spline,
                     x = rep(data$x, times = 2),
                     fitted = rowSums(spline),
                     type = rep(c('unweighted', 'weighted'),
                                each = nrow(sm_gp)))
spline <- pivot_longer(spline, -(x:type), names_to = 'bf')
## Animation for the Gaussian process smooth: observations as points, basis
## functions coloured by `bf`, and the fitted spline in black, animated over
## `type` (unweighted vs weighted basis).
pgp <- ggplot(data, aes(x = x, y = yobs)) +
  geom_point(alpha = 0.2) +
  geom_line(data = basis,
            mapping = aes(x = x, y = value, colour = bf),
            lwd = 1, alpha = 0.5) +
  geom_line(data = spline,
            mapping = aes(x = x, y = fitted),
            lwd = 1, colour = 'black', alpha = 0.75) +
  ## guides(colour = FALSE) is deprecated since ggplot2 3.3.4
  guides(colour = "none") +
  labs(y = 'f(x)', x = 'x', title = 'GAMs: learning from data',
       subtitle = 'Gaussian process — Matérn(κ=1.5; ρ=0.25)',
       caption = '@ucfagls') +
  theme_minimal(base_size = 14, base_family = 'Titillium')
pgp
animate(pgp + transition_states(type, transition_length = 4, state_length = 2) +
          ease_aes('cubic-in-out'),
        nframes = 100, height = 700, width = 800, res = 120)
## unweighted and weighted basis functions side by side, each panel on its
## own y scale
ggplot(basis, aes(x = x, y = value, colour = bf)) +
  geom_line() +
  facet_wrap(~ type, scales = 'free_y')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment