Created
November 12, 2024 15:24
-
-
Save eggsyntax/6acd209f7bbc952ce7ede348b349e24a to your computer and use it in GitHub Desktop.
Log-loss heatmap
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import React, { useState } from 'react'; | |
const LogLossHeatmap = () => { | |
const [hoverInfo, setHoverInfo] = useState(null); | |
// Expected binary log loss (cross-entropy, natural log) of a predictor that
// states probability `confidence` for the outcome it predicts and is correct
// a fraction `accuracy` of the time:
//   L = -[accuracy * ln(confidence) + (1 - accuracy) * ln(1 - confidence)]
// Each term follows the standard 0 * ln(0) = 0 convention, so boundary inputs
// (confidence 0 or 1) give 0 or +Infinity (then capped) instead of NaN.
// The result is clamped to [0, maxLoss] so the colour scale has a finite ceiling.
//
// @param {number} confidence - stated probability of the predicted outcome, in [0, 1]
// @param {number} accuracy - fraction of predictions that are correct, in [0, 1]
// @param {number} [maxLoss=4] - visualisation cap applied to the result
// @returns {number} log loss in nats, clamped to [0, maxLoss]
const calculateLogLoss = (confidence, accuracy, maxLoss = 4) => {
  // weight * ln(p), defined as 0 when the weight is 0 (avoids 0 * -Infinity = NaN).
  const xLogY = (weight, p) => (weight === 0 ? 0 : weight * Math.log(p));
  const loss = -(xLogY(accuracy, confidence) + xLogY(1 - accuracy, 1 - confidence));
  // Math.max also normalises -0 (perfect prediction) to 0.
  return Math.min(maxLoss, Math.max(0, loss));
};
// Map a log-loss score onto a five-stop viridis-like ramp (yellow -> purple).
// The score is normalised against a ceiling of 4 and clamped to [0, 1], so any
// score at or above 4 renders as the final purple and anything <= 0 as yellow.
// Returns a CSS colour string of the form "rgb(r,g,b)".
const getColor = (score) => {
  const ceiling = 4;
  const t = Math.min(1, Math.max(0, score / ceiling));

  const palette = [
    [253, 231, 37], // yellow
    [94, 201, 98], // green
    [32, 144, 141], // teal
    [59, 82, 139], // blue
    [68, 1, 84], // purple
  ];
  const spans = palette.length - 1;
  const scaled = t * spans;
  // Clamp so t === 1 lands at the end of the last span rather than past it.
  const idx = Math.min(Math.floor(scaled), spans - 1);
  const local = scaled - idx;

  const from = palette[idx];
  const to = palette[idx + 1];
  const [r, g, b] = from.map((start, k) => Math.round(start + (to[k] - start) * local));
  return `rgb(${r},${g},${b})`;
};
// Plot-area size in SVG user units, margins around it, and heatmap resolution.
const width = 600;
const height = 450;
const margin = { top: 50, right: 100, bottom: 60, left: 60 };
const cells = 200; // cells per axis (cells x cells grid); high resolution for a smooth gradient
const cellWidth = width / cells;
const cellHeight = height / cells;

// Heatmap grid, flattened column by column. Cell (i, j) occupies
// [i * cellWidth, (i + 1) * cellWidth] x [j * cellHeight, (j + 1) * cellHeight],
// so the grid tiles the plot area exactly. Its colour is sampled at the cell
// centre using the same mapping as the axes and the hover handler:
// confidence runs 0.5 -> 1.0 left to right, accuracy 1.0 -> 0.0 top to bottom.
//
// Memoised so the cells * cells calculateLogLoss/getColor evaluations run once
// instead of on every hover-driven re-render. The inputs are the constants
// above plus two pure helpers, so the helpers' per-render identity does not
// affect the result.
const cells_array = React.useMemo(
  () =>
    Array.from({ length: cells }, (_, i) => {
      const confidence = 0.5 + ((i + 0.5) / cells) * 0.5;
      return Array.from({ length: cells }, (_, j) => {
        const accuracy = 1 - (j + 0.5) / cells;
        const logLoss = calculateLogLoss(confidence, accuracy);
        return {
          x: i * cellWidth,
          y: j * cellHeight,
          confidence,
          accuracy,
          logLoss,
          color: getColor(logLoss),
        };
      });
    }).flat(),
  [cells, cellWidth, cellHeight],
);
// Convert the pointer position over the SVG into plot coordinates and publish
// { confidence, accuracy, logLoss } as hover state. Pointers over the margins
// (outside the plot area) clear the hover state instead.
const handleMouseMove = (event) => {
  const { left, top } = event.currentTarget.getBoundingClientRect();
  // Fractions across the plot area; y is flipped so 1 is the top edge
  // (accuracy 1.0), matching the y-axis ticks.
  const fracX = (event.clientX - left - margin.left) / width;
  const fracY = 1 - (event.clientY - top - margin.top) / height;

  const insidePlot = fracX >= 0 && fracX <= 1 && fracY >= 0 && fracY <= 1;
  if (!insidePlot) {
    setHoverInfo(null);
    return;
  }

  // The x-axis spans confidence 0.5 .. 1.0.
  const confidence = 0.5 + fracX * 0.5;
  const accuracy = fracY;
  setHoverInfo({ confidence, accuracy, logLoss: calculateLogLoss(confidence, accuracy) });
};
// The pointer has left the SVG entirely: clear the hover state so the tooltip hides.
const handleMouseLeave = () => void setHoverInfo(null);
return ( | |
<div className="w-full max-w-4xl"> | |
<h2 className="text-xl font-bold mb-4">Log Loss Heatmap</h2> | |
<div className="relative"> | |
<svg | |
width={width + margin.left + margin.right} | |
height={height + margin.top + margin.bottom} | |
onMouseMove={handleMouseMove} | |
onMouseLeave={handleMouseLeave} | |
> | |
<g transform={`translate(${margin.left},${margin.top})`}> | |
{/* Heatmap cells */} | |
{cells_array.map((cell, i) => ( | |
<rect | |
key={i} | |
x={cell.x} | |
y={cell.y} | |
width={cellWidth + 0.5} | |
height={cellHeight + 0.5} | |
fill={cell.color} | |
/> | |
))} | |
{/* Axes */} | |
<line x1="0" y1={height} x2={width} y2={height} stroke="black" strokeWidth="2"/> | |
<line x1="0" y1="0" x2="0" y2={height} stroke="black" strokeWidth="2"/> | |
{/* X-axis ticks and labels */} | |
{[0.5, 0.625, 0.75, 0.875, 1].map(tick => ( | |
<g key={tick} transform={`translate(${(tick - 0.5) * width * 2},${height})`}> | |
<line y2="6" stroke="black"/> | |
<text y="20" textAnchor="middle">{tick.toFixed(3)}</text> | |
</g> | |
))} | |
{/* Y-axis ticks and labels */} | |
{[0, 0.25, 0.5, 0.75, 1].map(tick => ( | |
<g key={tick} transform={`translate(0,${height - tick * height})`}> | |
<line x2="-6" stroke="black"/> | |
<text x="-10" dy="0.32em" textAnchor="end">{tick.toFixed(2)}</text> | |
</g> | |
))} | |
{/* Axis labels */} | |
<text | |
x={width/2} | |
y={height + 40} | |
textAnchor="middle" | |
>Prediction Confidence</text> | |
<text | |
transform={`translate(-40,${height/2}) rotate(-90)`} | |
textAnchor="middle" | |
>Actual Accuracy</text> | |
{/* Color scale legend */} | |
<defs> | |
<linearGradient id="legendGradient" x1="0" x2="0" y1="0" y2="1"> | |
<stop offset="0%" stopColor="rgb(253,231,37)"/> | |
<stop offset="25%" stopColor="rgb(94,201,98)"/> | |
<stop offset="50%" stopColor="rgb(32,144,141)"/> | |
<stop offset="75%" stopColor="rgb(59,82,139)"/> | |
<stop offset="100%" stopColor="rgb(68,1,84)"/> | |
</linearGradient> | |
</defs> | |
<rect | |
x={width + 40} | |
y={0} | |
width="20" | |
height={height} | |
fill="url(#legendGradient)" | |
/> | |
<text | |
x={width + 70} | |
y={0} | |
dy="0.32em" | |
textAnchor="start" | |
>0.0</text> | |
<text | |
x={width + 70} | |
y={height} | |
dy="0.32em" | |
textAnchor="start" | |
>4.0+</text> | |
<text | |
transform={`translate(${width + 90},${height/2}) rotate(-90)`} | |
textAnchor="middle" | |
>Log Loss</text> | |
</g> | |
</svg> | |
{/* Hover information */} | |
{hoverInfo && ( | |
<div className="absolute bg-white p-2 rounded shadow border border-gray-200"> | |
<p className="text-sm">Confidence: {hoverInfo.confidence.toFixed(3)}</p> | |
<p className="text-sm">Accuracy: {hoverInfo.accuracy.toFixed(3)}</p> | |
<p className="text-sm">Log Loss: {hoverInfo.logLoss.toFixed(3)}</p> | |
</div> | |
)} | |
</div> | |
<div className="mt-4 text-sm text-gray-600"> | |
<p>Lower log loss (yellow) indicates better calibrated predictions</p> | |
<p>Higher log loss (purple) indicates poorly calibrated predictions</p> | |
</div> | |
</div> | |
); | |
}; | |
export default LogLossHeatmap; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment