using CounterfactualExplanations.Data: load_mnist
using CSV
using DataFrames
using Flux
using GMT
using Images
using LinearAlgebra
using MLJBase
using MLJModels
using OneHotArrays
using ConformalPrediction
using Distributions
using MLJ
using Plots
# Inputs:
N = 600
xmax = 3.0
d = Uniform(-xmax, xmax)
# Simple
"The `SimpleInductiveClassifier` is the simplest approach to Inductive Conformal Classification. Unlike the [`NaiveClassifier`](@ref), it computes nonconformity scores using a designated calibration dataset."
mutable struct SimpleInductiveClassifier{Model <: Supervised} <: ConformalSet
    model::Model
    coverage::AbstractFloat
    scores::Union{Nothing,AbstractArray}
    heuristic::Function
    train_ratio::AbstractFloat
end
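The docstring above describes split (inductive) conformal classification. The sketch below illustrates that procedure in plain Julia, independently of ConformalPrediction.jl. It assumes the common nonconformity score 1 - p̂(y) (one minus the predicted probability of the true label) and the finite-sample quantile rank ⌈(n + 1)·coverage⌉. The names `nonconformity`, `calibrate` and `prediction_set` are hypothetical helpers, not the package's API, and the calibration probabilities are made up.

```julia
# Hypothetical helpers illustrating split (inductive) conformal classification;
# this is NOT the ConformalPrediction.jl API.

# Assumed nonconformity score: one minus the probability assigned to the true class.
nonconformity(p::AbstractVector, y::Integer) = 1 - p[y]

# Conformal threshold from a held-out calibration set.
# `probs_cal`: vector of predicted probability vectors; `y_cal`: true class indices.
function calibrate(probs_cal, y_cal, coverage)
    scores = sort([nonconformity(p, y) for (p, y) in zip(probs_cal, y_cal)])
    n = length(scores)
    k = ceil(Int, (n + 1) * coverage)  # finite-sample corrected rank
    return k > n ? Inf : scores[k]     # too few calibration points: include every class
end

# Prediction set: all classes whose score does not exceed the threshold.
prediction_set(p::AbstractVector, q̂) = [k for k in eachindex(p) if 1 - p[k] <= q̂]

# Usage with made-up calibration probabilities for three classes:
probs_cal = [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1], [0.2, 0.5, 0.3], [0.1, 0.1, 0.8]]
y_cal = [1, 1, 2, 3]
q̂ = calibrate(probs_cal, y_cal, 0.8)
prediction_set([0.5, 0.5, 0.0], q̂)  # → [1, 2]
```

Returning `Inf` when the rank exceeds the number of calibration points makes the set contain every class, which is the conservative choice when the calibration set is too small for the requested coverage.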
import CounterfactualExplanations.Generators: ∂ℓ
using LinearAlgebra
# Counterfactual loss:
function ∂ℓ(
    generator::AbstractGradientBasedGenerator,
    counterfactual_state::CounterfactualState)
    M = counterfactual_state.M
    nn = M.nn
    x′ = counterfactual_state.x′
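The `∂ℓ` method above is cut off after unpacking the counterfactual state. For intuition about what such a gradient looks like, the sketch below assumes a hypothetical binary logistic classifier σ(wᵀx′ + b) and a binary cross-entropy loss towards target class t, for which the gradient with respect to the counterfactual is (σ(wᵀx′ + b) - t)·w. It is not the package's implementation (which works with the wrapped model `M.nn`), and its fixed-step descent loop omits any distance or complexity penalty. `counterfactual_loss`, `∂ℓ_logistic` and `search_counterfactual` are hypothetical names.

```julia
# Hypothetical binary logistic classifier with weights w and bias b (assumed for illustration).
σ(z) = 1 / (1 + exp(-z))

# Binary cross-entropy between the prediction at x′ and the target class t ∈ {0, 1}.
function counterfactual_loss(x′, w, b, t)
    p = σ(w' * x′ + b)
    return -(t * log(p) + (1 - t) * log(1 - p))
end

# Analytic gradient of `counterfactual_loss` with respect to x′.
∂ℓ_logistic(x′, w, b, t) = (σ(w' * x′ + b) - t) .* w

# Fixed-step gradient descent on x′ (no distance penalty, no convergence check).
function search_counterfactual(x, w, b, t; η = 0.5, steps = 100)
    x′ = copy(x)
    for _ in 1:steps
        x′ .-= η .* ∂ℓ_logistic(x′, w, b, t)
    end
    return x′
end

w, b = [1.0, -2.0], 0.5
x = [0.0, 1.0]                      # factual input, predicted class 0 (σ(-1.5) ≈ 0.18)
x_cf = search_counterfactual(x, w, b, 1)
σ(w' * x_cf + b) > 0.5              # → true: the counterfactual is predicted as class 1
```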
using Flux
using CounterfactualExplanations, CounterfactualExplanations.Models
import CounterfactualExplanations.Models: logits, probs # import functions in order to extend
# Step 1)
struct TorchNetwork <: Models.AbstractFittedModel
    nn::Any
end
# Step 2)
# Import libraries.
using Flux, Plots, Random, PlotThemes, Statistics, BayesLaplace
theme(:wong)
# Toy data:
xs, y = toy_data_linear(100)
X = hcat(xs...); # bring into tabular format
data = zip(xs,y)
# Build MLP:
# Import libraries.
using Flux, Plots, Random, PlotThemes, Statistics, BayesLaplace
theme(:wong)
# Toy data:
xs, y = toy_data_linear(100)
X = hcat(xs...); # bring into tabular format
data = zip(xs,y)
# Neural network:
# Newton's Method
# Armijo (sufficient-decrease) condition: does step size ρ along d_t decrease 𝓁 enough?
function arminjo(𝓁, g_t, θ_t, d_t, args, ρ, c=1e-4)
    𝓁(θ_t .+ ρ .* d_t, args...) <= 𝓁(θ_t, args...) .+ c .* ρ .* d_t'g_t
end
function newton(𝓁, θ, ∇𝓁, ∇∇𝓁, args; max_iter=100, τ=1e-5)
    # Initialize:
    converged = false # termination state
    t = 1 # iteration count
    θ_t = θ # initial parameters
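`arminjo` checks the Armijo sufficient-decrease condition 𝓁(θ + ρd) ≤ 𝓁(θ) + cρ·dᵀg, and the truncated `newton` presumably uses it to choose a step size. The sketch below shows the usual way such a check is applied, backtracking: start from ρ = 1 and shrink ρ by a factor β until the condition holds. `armijo_ok` and `backtrack` are hypothetical names; β = 0.5 and the cap of 50 reductions are assumptions.

```julia
# Armijo sufficient-decrease condition, the same inequality as `arminjo`,
# but taking the objective as a closure `f` instead of splatting `args`.
armijo_ok(f, θ, d, g, ρ; c = 1e-4) = f(θ .+ ρ .* d) <= f(θ) + c * ρ * (d' * g)

# Backtracking line search: shrink ρ by β until the Armijo condition holds.
# If it never holds within `max_reductions` tries, the (untested) smallest step is returned.
function backtrack(f, θ, d, g; ρ = 1.0, β = 0.5, c = 1e-4, max_reductions = 50)
    for _ in 1:max_reductions
        armijo_ok(f, θ, d, g, ρ; c = c) && return ρ
        ρ *= β
    end
    return ρ
end

# Usage: f(θ) = θᵀθ at θ = [1.0] with steepest-descent direction d = -∇f(θ) = [-2.0].
f(θ) = θ' * θ
θ = [1.0]
g = 2 .* θ
backtrack(f, θ, -g, g)  # → 0.5 (ρ = 1 overshoots to θ = -1, where f does not decrease)
```

In a Newton method the direction would be d = -H⁻¹g rather than -g; the backtracking loop itself is unchanged.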
# Loss: negative Bernoulli log-likelihood plus a Gaussian prior term centered at w_0 with precision H_0
function 𝓁(w,w_0,H_0,X,y)
    N = length(y)
    D = size(X, 2)
    μ = sigmoid(w,X)
    Δw = w-w_0
    l = - ∑( y[n] * log(μ[n]) + (1-y[n]) * log(1-μ[n]) for n=1:N) + 1/2 * Δw'H_0*Δw
    return l
end
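The loss 𝓁 above is the negative log-posterior of Bayesian logistic regression with a Gaussian prior (mean w_0, precision H_0), and the `newton` signature suggests it is passed the matching gradient ∇𝓁 and Hessian ∇∇𝓁. A minimal sketch of both, plus undamped Newton iterations to the MAP estimate, follows. It assumes the undefined helper `sigmoid(w, X)` computes σ(Xw) element-wise and that `∑` is `sum`; the names `loss`, `∇loss`, `∇∇loss` and `newton_map`, and the toy data, are hypothetical. At the MAP estimate, the Hessian is the posterior precision used by a Laplace approximation.

```julia
using LinearAlgebra

# Assumed meaning of the undefined helper in 𝓁: element-wise logistic function of Xw.
sigmoid(w, X) = 1 ./ (1 .+ exp.(-(X * w)))

# Vectorized version of 𝓁: negative log-likelihood + ½ (w - w_0)ᵀ H_0 (w - w_0).
function loss(w, w_0, H_0, X, y)
    μ = sigmoid(w, X)
    Δw = w - w_0
    return -sum(y .* log.(μ) .+ (1 .- y) .* log.(1 .- μ)) + 0.5 * Δw' * H_0 * Δw
end

# Gradient: Xᵀ(μ - y) + H_0 (w - w_0)
∇loss(w, w_0, H_0, X, y) = X' * (sigmoid(w, X) - y) + H_0 * (w - w_0)

# Hessian: Xᵀ diag(μ ⊙ (1 - μ)) X + H_0
function ∇∇loss(w, w_0, H_0, X, y)
    μ = sigmoid(w, X)
    return X' * Diagonal(μ .* (1 .- μ)) * X + H_0
end

# Undamped Newton iterations (no line search), stopping once the gradient norm < τ.
function newton_map(w, w_0, H_0, X, y; max_iter = 100, τ = 1e-8)
    for _ in 1:max_iter
        g = ∇loss(w, w_0, H_0, X, y)
        norm(g) < τ && break
        w = w - ∇∇loss(w, w_0, H_0, X, y) \ g
    end
    return w
end

# Toy data: intercept column plus one feature (made up for illustration).
X = [1.0 -1.0; 1.0 0.5; 1.0 2.0; 1.0 -0.3]
y = [0.0, 1.0, 1.0, 0.0]
w_0, H_0 = zeros(2), Matrix(1.0I, 2, 2)
w_map = newton_map(zeros(2), w_0, H_0, X, y)
Σ = inv(∇∇loss(w_map, w_0, H_0, X, y))  # Laplace posterior covariance
```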
logit <- function(X, y, beta_0=NULL, tau=1e-9, max_iter=10000) {
  # Prepend an intercept column if the first column is not all ones:
  if(!all(X[,1]==1)) {
    X <- cbind(1,X)
  }
  p <- ncol(X)
  n <- nrow(X)
  # Initialization: ----
  if (is.null(beta_0)) {
    beta_latest <- matrix(rep(0, p)) # naive first guess
  }