Skip to content

Instantly share code, notes, and snippets.

@eggsyntax
Created November 12, 2024 15:24
Show Gist options
  • Save eggsyntax/6acd209f7bbc952ce7ede348b349e24a to your computer and use it in GitHub Desktop.
Save eggsyntax/6acd209f7bbc952ce7ede348b349e24a to your computer and use it in GitHub Desktop.
Log-loss heatmap
import React, { useMemo, useState } from 'react';
const LogLossHeatmap = () => {
const [hoverInfo, setHoverInfo] = useState(null);
// Calculate log loss for a predictor who predicts with confidence p
// and has accuracy a
const calculateLogLoss = (confidence, accuracy) => {
// Handle edge cases to avoid infinity/NaN
if (confidence === 1.0) {
if (accuracy === 1.0) return 0;
if (accuracy === 0.0) return 4; // Cap at our max value
}
if (confidence === 0.5 && accuracy === 1.0) return Math.log(2);
// For correct predictions: -log(confidence)
// For incorrect predictions: -log(1-confidence)
const result = -(accuracy * Math.log(confidence) +
(1 - accuracy) * Math.log(1 - confidence));
// Cap at 4 for visualization purposes
return Math.min(4, result);
};
// Convert log loss to color using a viridis-like scale
const getColor = (score) => {
// Log loss of 4 is already quite bad (predicting with high confidence and being wrong)
const maxScore = 4;
const factor = Math.min(1, Math.max(0, score / maxScore));
// Define key colors in the scale
const colors = [
[253, 231, 37], // Yellow
[94, 201, 98], // Green
[32, 144, 141], // Teal
[59, 82, 139], // Blue
[68, 1, 84] // Purple
];
const numSegments = colors.length - 1;
const segment = Math.min(Math.floor(factor * numSegments), numSegments - 1);
const segmentFactor = (factor * numSegments) - segment;
const c1 = colors[segment];
const c2 = colors[segment + 1];
const red = Math.round(c1[0] + (c2[0] - c1[0]) * segmentFactor);
const green = Math.round(c1[1] + (c2[1] - c1[1]) * segmentFactor);
const blue = Math.round(c1[2] + (c2[2] - c1[2]) * segmentFactor);
return `rgb(${red},${green},${blue})`;
};
const width = 600;
const height = 450;
const margin = { top: 50, right: 100, bottom: 60, left: 60 };
const cells = 200; // High resolution for smooth gradient
const cellWidth = width / cells;
const cellHeight = height / cells;
// Generate cells for the heatmap
const cells_array = Array.from({ length: cells }, (_, i) => {
// Confidence ranges from 0.5 to 1.0
const confidence = 0.5 + (i / (cells - 1)) * 0.5;
return Array.from({ length: cells }, (_, j) => {
const accuracy = 1 - (j / (cells - 1));
const logLoss = calculateLogLoss(confidence, accuracy);
return {
x: (confidence - 0.5) * width * 2, // Scale x to fill width
y: j * height / cells,
confidence,
accuracy,
logLoss,
color: getColor(logLoss)
};
});
}).flat();
const handleMouseMove = (event) => {
const svgRect = event.currentTarget.getBoundingClientRect();
const x = (event.clientX - svgRect.left - margin.left) / width;
const y = 1 - (event.clientY - svgRect.top - margin.top) / height;
if (x >= 0 && x <= 1 && y >= 0 && y <= 1) {
const confidence = 0.5 + x * 0.5;
const accuracy = y;
const logLoss = calculateLogLoss(confidence, accuracy);
setHoverInfo({ confidence, accuracy, logLoss });
} else {
setHoverInfo(null);
}
};
const handleMouseLeave = () => {
setHoverInfo(null);
};
return (
<div className="w-full max-w-4xl">
<h2 className="text-xl font-bold mb-4">Log Loss Heatmap</h2>
<div className="relative">
<svg
width={width + margin.left + margin.right}
height={height + margin.top + margin.bottom}
onMouseMove={handleMouseMove}
onMouseLeave={handleMouseLeave}
>
<g transform={`translate(${margin.left},${margin.top})`}>
{/* Heatmap cells */}
{cells_array.map((cell, i) => (
<rect
key={i}
x={cell.x}
y={cell.y}
width={cellWidth + 0.5}
height={cellHeight + 0.5}
fill={cell.color}
/>
))}
{/* Axes */}
<line x1="0" y1={height} x2={width} y2={height} stroke="black" strokeWidth="2"/>
<line x1="0" y1="0" x2="0" y2={height} stroke="black" strokeWidth="2"/>
{/* X-axis ticks and labels */}
{[0.5, 0.625, 0.75, 0.875, 1].map(tick => (
<g key={tick} transform={`translate(${(tick - 0.5) * width * 2},${height})`}>
<line y2="6" stroke="black"/>
<text y="20" textAnchor="middle">{tick.toFixed(3)}</text>
</g>
))}
{/* Y-axis ticks and labels */}
{[0, 0.25, 0.5, 0.75, 1].map(tick => (
<g key={tick} transform={`translate(0,${height - tick * height})`}>
<line x2="-6" stroke="black"/>
<text x="-10" dy="0.32em" textAnchor="end">{tick.toFixed(2)}</text>
</g>
))}
{/* Axis labels */}
<text
x={width/2}
y={height + 40}
textAnchor="middle"
>Prediction Confidence</text>
<text
transform={`translate(-40,${height/2}) rotate(-90)`}
textAnchor="middle"
>Actual Accuracy</text>
{/* Color scale legend */}
<defs>
<linearGradient id="legendGradient" x1="0" x2="0" y1="0" y2="1">
<stop offset="0%" stopColor="rgb(253,231,37)"/>
<stop offset="25%" stopColor="rgb(94,201,98)"/>
<stop offset="50%" stopColor="rgb(32,144,141)"/>
<stop offset="75%" stopColor="rgb(59,82,139)"/>
<stop offset="100%" stopColor="rgb(68,1,84)"/>
</linearGradient>
</defs>
<rect
x={width + 40}
y={0}
width="20"
height={height}
fill="url(#legendGradient)"
/>
<text
x={width + 70}
y={0}
dy="0.32em"
textAnchor="start"
>0.0</text>
<text
x={width + 70}
y={height}
dy="0.32em"
textAnchor="start"
>4.0+</text>
<text
transform={`translate(${width + 90},${height/2}) rotate(-90)`}
textAnchor="middle"
>Log Loss</text>
</g>
</svg>
{/* Hover information */}
{hoverInfo && (
<div className="absolute bg-white p-2 rounded shadow border border-gray-200">
<p className="text-sm">Confidence: {hoverInfo.confidence.toFixed(3)}</p>
<p className="text-sm">Accuracy: {hoverInfo.accuracy.toFixed(3)}</p>
<p className="text-sm">Log Loss: {hoverInfo.logLoss.toFixed(3)}</p>
</div>
)}
</div>
<div className="mt-4 text-sm text-gray-600">
<p>Lower log loss (yellow) indicates better calibrated predictions</p>
<p>Higher log loss (purple) indicates poorly calibrated predictions</p>
</div>
</div>
);
};
export default LogLossHeatmap;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment