Skip to content

Instantly share code, notes, and snippets.

@So-Cool
Last active November 4, 2025 12:16
Show Gist options
  • Save So-Cool/b413a73be3ed3e2a717ac47d086c80eb to your computer and use it in GitHub Desktop.
Save So-Cool/b413a73be3ed3e2a717ac47d086c80eb to your computer and use it in GitHub Desktop.
ADM+S 2025 hands-on
fat-forensics==0.1.1
ipywidgets>=7.7.0
matplotlib>=3.3.0
numpy>=1.20.1
scikit-learn>=1.1.0
scipy>=1.6.1
Display the source blob
Display the rendered blob
Raw
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
"""
Tabular Surrogate Explainer Builder
===================================
This module implements a collection of helper functions for an interactive
builder of surrogate explainers of tabular data based on iPyWidgets.
See <https://github.com/fat-forensics/resources/tree/master/tabular_surrogate_builder>
for more details.
"""
# Author: Kacper Sokol <[email protected]>
# License: new BSD
from sklearn.linear_model import RidgeClassifier
import numpy as np
import ipywidgets as widgets
import matplotlib.pyplot as plt
# Matplotlib 3.6 renamed the bundled seaborn style sheets to 'seaborn-v0_8*'.
# Older releases only know the legacy 'seaborn' name and raise OSError for the
# new one, so fall back to keep the module importable on those versions.
try:
    plt.style.use('seaborn-v0_8')
except OSError:
    plt.style.use('seaborn')
def build_tabular_blimey(
        instance,
        class_to_explain,
        sampled_data,
        prediction_fn,
        discretisation,
        fit_intercept=True,
        random_seed=42
):
    """
    Fits a ridge-classifier surrogate (tabular bLIMEy) around one instance.

    The sample is labelled one-vs-rest with the black box: ``1`` where
    ``prediction_fn`` outputs ``class_to_explain`` and ``0`` otherwise.
    Features ``0`` and ``1`` of every sample and of the explained instance
    are binned with ``np.digitize``; each sample is then encoded with two
    binary indicators stating whether it shares the instance's bin for the
    respective feature.  A ``RidgeClassifier`` is trained on this encoding.

    Parameters
    ----------
    instance : indexable
        Explained data point; only elements ``0`` and ``1`` are read.
    class_to_explain
        Black-box output value treated as the positive class.
    sampled_data : 2-dimensional numpy array
        Data used to train the surrogate; only columns ``0`` and ``1`` are
        read.
    prediction_fn : callable
        Black box, called as ``prediction_fn(sampled_data)``.  Its output is
        compared element-wise with ``class_to_explain`` (assumes a numpy
        array is returned -- confirm against the caller).
    discretisation : indexable
        ``discretisation[0]`` and ``discretisation[1]`` hold the bin edges
        passed to ``np.digitize`` for the first and second feature.
    fit_intercept : bool
        Forwarded to ``RidgeClassifier``.
    random_seed : int
        Forwarded to ``RidgeClassifier`` as ``random_state``.

    Returns
    -------
    The ``coef_`` attribute of the fitted ``RidgeClassifier``.
    """
    feature_ids = (0, 1)

    # One-vs-rest target: is the black-box prediction the explained class?
    is_explained_class = prediction_fn(sampled_data) == class_to_explain
    target = is_explained_class.astype(np.int8)

    # Bin index of every sample, one column per feature.
    sample_bins = np.column_stack([
        np.digitize(sampled_data[:, idx], discretisation[idx])
        for idx in feature_ids
    ])
    # Bin index of the explained instance, one entry per feature.
    instance_bins = np.array([
        np.digitize(instance[idx], discretisation[idx])
        for idx in feature_ids
    ])

    # Interpretable representation: 1 if the sample falls into the same bin
    # as the explained instance for that feature, 0 otherwise.
    in_instance_bin = (sample_bins == instance_bins).astype(np.int8)

    surrogate = RidgeClassifier(
        fit_intercept=fit_intercept, random_state=random_seed)
    surrogate.fit(in_instance_bin, target)

    return surrogate.coef_
def plot_tabular_explanation(
        instance,
        class_to_explain,
        sampled_data,
        prediction_fn,
        discretisation,
        feature_ranges,
        feature_names,
        explanation
):
    """
    Plots a tabular bLIMEy explanation.

    The left panel shows the black-box decision surface over the two
    features, the sampled data coloured by their black-box prediction, the
    explained instance (yellow star), the discretisation thresholds and a
    hatched highlight of the bin that holds the instance.  The right panel
    shows the normalised magnitude of the two surrogate weights as
    horizontal bars (green for non-negative, red for negative weights).

    Parameters
    ----------
    instance : indexable
        Explained data point; only elements ``0`` and ``1`` are read.
    class_to_explain
        Label shown in the figure title.
    sampled_data : 2-dimensional numpy array
        Sample to scatter; only columns ``0`` and ``1`` are read.
    prediction_fn : callable
        Black box; called on the plotting grid and on ``sampled_data``
        (its output is reshaped, so a numpy array is expected).
    discretisation : indexable
        ``discretisation[0]`` / ``discretisation[1]`` hold the bin edges of
        the first / second feature.
    feature_ranges : indexable
        ``feature_ranges[i]`` is the ``(min, max)`` pair of feature ``i``;
        the axes extend 0.5 beyond it on either side.
    feature_names : indexable
        ``feature_names[0]`` / ``feature_names[1]`` are the axis labels.
    explanation : array-like
        The two surrogate weights.  Both a flat array and the 2-dimensional
        ``(1, 2)`` array stored in ``RidgeClassifier.coef_`` are accepted.
    """
    # ``RidgeClassifier.coef_`` is 2-dimensional; flatten it so that the
    # per-weight loops below always receive scalars.
    explanation = np.asarray(explanation, dtype=np.float64).ravel()
    # Normalise the weights to unit L1 norm; an all-zero explanation is left
    # untouched to avoid a division by zero (which would plot NaNs).
    weight_norm = np.abs(explanation).sum()
    if weight_norm > 0:
        explanation = explanation / weight_norm

    fig, (ax_l, ax_r) = plt.subplots(1, 2, figsize=(18, 6))
    fig.patch.set_alpha(0)
    fig.suptitle('Explained class: {}'.format(class_to_explain), fontsize=18)

    # ---- left panel: decision surface, sample and discretisation ----
    x_min, x_max = feature_ranges[0][0] - .5, feature_ranges[0][1] + .5
    y_min, y_max = feature_ranges[1][0] - .5, feature_ranges[1][1] + .5

    plot_step = 0.02
    xx, yy = np.meshgrid(np.arange(x_min, x_max, plot_step),
                         np.arange(y_min, y_max, plot_step))
    Z = prediction_fn(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    ax_l.contourf(xx, yy, Z, cmap=plt.cm.RdYlBu)

    ax_l.scatter(sampled_data[:, 0],
                 sampled_data[:, 1],
                 c=prediction_fn(sampled_data),
                 cmap=plt.cm.Set1, edgecolor='k')
    ax_l.set_xlabel(feature_names[0], fontsize=18)
    ax_l.set_ylabel(feature_names[1], fontsize=18)

    ax_l.set_xlim(x_min, x_max)
    ax_l.set_ylim(y_min, y_max)

    ax_l.scatter(instance[0], instance[1],
                 c='yellow', marker='*', s=500, edgecolor='k')
    # Draw the thresholds across the whole visible area, whatever the
    # feature ranges are.
    ax_l.vlines(discretisation[0], y_min, y_max, linewidth=3)
    ax_l.hlines(discretisation[1], x_min, x_max, linewidth=3)

    ax_l.tick_params(axis='x', labelsize=18)
    ax_l.tick_params(axis='y', labelsize=18)

    # ---- bin of the explained instance and its textual description ----
    bin_ids, bin_labels = [], []
    for idx in (0, 1):
        edges = discretisation[idx]
        bin_id = np.digitize(instance[idx], edges)
        edge_names = ['-inf'] + [str(e) for e in edges] + ['+inf']
        bin_ids.append(bin_id)
        # np.digitize (right=False) assigns bin i when
        # edges[i - 1] <= x < edges[i], hence the closed lower bound.
        bin_labels.append('{}\n{} <= ... < {}'.format(
            feature_names[idx], edge_names[bin_id], edge_names[bin_id + 1]))

    # ---- right panel: weight magnitudes as horizontal bars ----
    magnitudes = [abs(weight) for weight in explanation]
    colours = ['green' if weight >= 0 else 'red' for weight in explanation]
    ax_r.set_xlim([0, 1.20])
    ax_r.set_ylim([-.5, len(bin_labels) - .5])
    ax_r.grid(False, axis='y')
    ax_r.set_xticks([0, 0.2, 0.4, 0.6, 0.8, 1])
    ax_r.barh(bin_labels, magnitudes, height=.5, color=colours)
    # The bin descriptions are printed next to the bars instead of as ticks.
    ax_r.set_yticklabels([])
    for row, magnitude in enumerate(magnitudes):
        ax_r.text(magnitude + .02, row + .15, '{:.4f}'.format(magnitude),
                  fontweight='bold', fontsize=18)
        ax_r.text(magnitude + .02, row - .2, bin_labels[row],
                  fontweight='bold', fontsize=18)
    ax_r.tick_params(axis='x', labelsize=18)

    # ---- highlight the explained bin on the left panel ----
    # The outermost bins are bounded by the axis limits.
    x_edges = [x_min] + list(discretisation[0]) + [x_max]
    ax_l.axvspan(x_edges[bin_ids[0]], x_edges[bin_ids[0] + 1],
                 facecolor='None', hatch='/', alpha=1.0)
    y_edges = [y_min] + list(discretisation[1]) + [y_max]
    ax_l.axhspan(y_edges[bin_ids[1]], y_edges[bin_ids[1] + 1],
                 facecolor='None', hatch='\\', alpha=1.0)

    plt.show()
def _generate_data(data_samples_no, x_range, y_range, random_seed):
    """
    Draws uniform random 2-D points cell by cell on a rectangular grid.

    ``x_range`` and ``y_range`` list the grid edges along each axis.  Cells
    are visited row by row, starting with the row between the last two
    ``y_range`` edges and moving towards the first ones; within a row the
    cells follow the order of ``x_range``.  The k-th visited cell receives
    ``data_samples_no[k]`` points, drawn 0.1 away from each of its edges.

    NumPy's global random generator is re-seeded with ``random_seed`` before
    every cell, so all cells are sampled from the same random stream.

    Returns a numpy array with one row per point and two columns (x, y),
    stacked in the order in which the cells were visited.
    """
    n_x_edges = len(x_range)
    n_y_edges = len(y_range)
    assert not (n_x_edges <= 2 or n_y_edges <= 2)

    margin = 0.1
    y_edges_reversed = y_range[::-1]

    chunks = []
    cell_id = 0
    for row in range(n_y_edges - 1):
        y_upper = y_edges_reversed[row]
        y_lower = y_edges_reversed[row + 1]
        for col in range(n_x_edges - 1):
            np.random.seed(random_seed)
            lower_corner = (x_range[col] + margin, y_lower + margin)
            upper_corner = (x_range[col + 1] - margin, y_upper - margin)
            chunks.append(np.random.uniform(
                low=lower_corner,
                high=upper_corner,
                size=(data_samples_no[cell_id], 2)
            ))
            cell_id += 1

    return np.vstack(chunks)
def generate_tabular_widget(
        black_boxes,
        class_map,
        feature_specification,
        random_seed=42
):
    """
    Builds iPyWidget interactive tabular surrogate explainer.

    Parameters
    ----------
    black_boxes : dict
        Maps a display name to a black-box prediction function; the keys
        become the options of the "Black box" toggle.
    class_map : dict
        Maps a class display name to the value handed to
        ``build_tabular_blimey`` as ``class_to_explain``; the keys become
        the options of the "Class" toggle.
    feature_specification : indexable
        Entries ``0`` (x-axis) and ``1`` (y-axis) are dictionaries with the
        keys ``'name'``, ``'range'`` (``(min, max)`` pair), ``'instance'``
        (dictionary with ``'value'`` and ``'step'``) and
        ``'discretisation'`` (dictionary with ``'init'``, ``'range'`` and
        ``'step'``).
    random_seed : int
        Seed forwarded to the data sampler and to the surrogate model.

    Returns
    -------
    ``widgets.VBox`` holding all the controls and the output area.  The
    "Explain!" button is clicked once during construction, so the output
    area already contains an explanation for the initial settings.
    """
    axis_tags = ('X', 'Y')

    def _instance_input(idx):
        """Creates the numeric input of the explained instance's feature."""
        spec = feature_specification[idx]
        return widgets.BoundedFloatText(
            value=spec['instance']['value'],
            min=spec['range'][0],
            max=spec['range'][1],
            step=spec['instance']['step'],
            description='[{}] {}:'.format(axis_tags[idx], spec['name']),
            disabled=False
        )

    def _threshold_slider(idx):
        """Creates the two-threshold discretisation slider of a feature."""
        spec = feature_specification[idx]
        disc = spec['discretisation']
        # Show as many decimal places in the readout as the step has.
        step_ = str(disc['step'])
        prec_ = len(step_.split('.')[1]) if '.' in step_ else 0
        return widgets.FloatRangeSlider(
            value=disc['init'],
            min=disc['range'][0],
            max=disc['range'][1],
            step=disc['step'],
            description='[{}] {}:'.format(axis_tags[idx], spec['name']),
            disabled=False,
            continuous_update=False,
            orientation='horizontal',
            readout=True,
            readout_format='.{}f'.format(prec_)
        )

    def explain_action(obj):
        """Reads the widget state, fits the surrogate and plots it."""
        prediction_fn_ = black_boxes[lime_bb_toggle.value]
        instance_ = np.array(
            [lime_instance_toggle.children[0].value,
             lime_instance_toggle.children[1].value],
            dtype=np.float64
        )

        feature_names_ = {
            0: feature_specification[0]['name'],
            1: feature_specification[1]['name']
        }
        feature_ranges_ = {
            0: feature_specification[0]['range'],
            1: feature_specification[1]['range']
        }

        # Two distinct thresholds per axis are needed to get three bins.
        x_axis_range_ = x_axis_slider.value
        assert x_axis_range_[0] != x_axis_range_[1], (
            'The {} split values must not be identical.'.format(
                feature_names_[0]))
        y_axis_range_ = y_axis_slider.value
        assert y_axis_range_[0] != y_axis_range_[1], (
            'The {} split values must not be identical.'.format(
                feature_names_[1]))
        discretisation_ = {0: x_axis_range_, 1: y_axis_range_}

        # One sample count per grid cell, in the order of the 3x3 widget grid.
        data_samples_no = [i.value for i in sample_widget.children]

        # Grid edges: feature minimum, both thresholds, feature maximum.
        x_range_ = [feature_ranges_[0][0],
                    x_axis_range_[0],
                    x_axis_range_[1],
                    feature_ranges_[0][1]]
        y_range_ = [feature_ranges_[1][0],
                    y_axis_range_[0],
                    y_axis_range_[1],
                    feature_ranges_[1][1]]
        data_ = _generate_data(
            data_samples_no, x_range_, y_range_, random_seed)

        explained_class_ = lime_class_toggle.value
        explained_class_id_ = class_map[explained_class_]

        explanation_ = build_tabular_blimey(
            instance_,
            explained_class_id_,
            data_,
            prediction_fn_,
            discretisation_,
            fit_intercept=model_intercept_widget.value,
            random_seed=random_seed
        )

        with lime_explain_out:
            lime_explain_out.clear_output(wait=True)
            plot_tabular_explanation(
                instance_,
                explained_class_,
                data_,
                prediction_fn_,
                discretisation_,
                feature_ranges_,
                feature_names_,
                explanation_
            )
            plt.show()

    # Explained instance -- one numeric input per feature
    lime_instance_toggle = widgets.GridBox(
        [_instance_input(0), _instance_input(1)],
        layout=widgets.Layout(grid_template_columns='repeat(2, 315px)')
    )

    # Select a class to explain
    lime_class_toggle = widgets.ToggleButtons(
        options=list(class_map.keys()),
        description='Class:',
        disabled=False,
        button_style=''  # 'success', 'info', 'warning', 'danger' or ''
    )

    # Number of samples per cell of the 3x3 discretisation grid
    items_sample_widget = [
        widgets.BoundedIntText(
            value=100,
            min=0,
            max=1000,
            step=1,
            description='Samples #:',
            disabled=False
        )
        for _ in range(9)
    ]
    sample_widget = widgets.GridBox(
        items_sample_widget,
        layout=widgets.Layout(grid_template_columns='repeat(3, 300px)'))

    # Select two thresholds per feature for the segmentation
    x_axis_slider = _threshold_slider(0)
    y_axis_slider = _threshold_slider(1)

    model_intercept_widget = widgets.Checkbox(
        value=True,
        description='Model intercept?',
        disabled=False,
        indent=False
    )

    lime_bb_toggle = widgets.ToggleButtons(
        options=list(black_boxes.keys()),
        description='Black box:',
        disabled=False,
        button_style=''
    )

    lime_explain_button = widgets.Button(
        description='Explain!',
        disabled=False,
        button_style='info',
        tooltip='Explain',
        icon='check'
    )
    lime_explain_out = widgets.Output()
    lime_explain_button.on_click(explain_action)
    # Pre-click the button so an explanation is shown straight away.
    lime_explain_button.click()

    interactive_explainer = widgets.VBox([
        lime_instance_toggle,
        lime_class_toggle,
        sample_widget,
        x_axis_slider,
        y_axis_slider,
        model_intercept_widget,
        lime_bb_toggle,
        lime_explain_button,
        lime_explain_out
    ])
    return interactive_explainer

Comments are disabled for this gist.