Created
November 12, 2024 15:12
-
-
Save eggsyntax/e0ea4b3299f0c597e83a663f8cdd5fbf to your computer and use it in GitHub Desktop.
Brier scores heatmap
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import React, { useState } from 'react'; | |
const BrierScoreHeatmap = () => { | |
const [hoverInfo, setHoverInfo] = useState(null); | |
/**
 * Expected Brier score of a predictor who always states the same confidence.
 *
 * A fraction `accuracy` of the predictions are correct and each contributes
 * (confidence - 1)^2; the remaining (1 - accuracy) are incorrect and each
 * contributes (confidence - 0)^2. The result is the weighted mean of the two.
 *
 * @param {number} confidence - Stated probability of being right.
 * @param {number} accuracy - Proportion of predictions that are actually right.
 * @returns {number} Mean squared error; within [0, 1] when both inputs are in [0, 1].
 */
const calculateBrierScore = (confidence, accuracy) => {
  const errorWhenCorrect = (confidence - 1) ** 2;
  const errorWhenIncorrect = (confidence - 0) ** 2;
  return accuracy * errorWhenCorrect + (1 - accuracy) * errorWhenIncorrect;
};
/**
 * Convert a Brier score to a CSS `rgb(r,g,b)` string on a viridis-like scale.
 *
 * The score is clamped to [0, 1] and linearly interpolated, per channel,
 * between five evenly spaced key colors running from yellow (0) to purple (1).
 *
 * @param {number} brierScore - Score to colorize; values outside [0, 1] are clamped.
 * @returns {string} Color string; a neutral gray when the score is not a number.
 */
const getColor = (brierScore) => {
  const factor = Math.min(1, Math.max(0, brierScore)); // clamp to [0,1]

  // Math.min/Math.max propagate NaN (and turn undefined into NaN). Without
  // this guard the segment index would be NaN and the lookup below would throw.
  if (Number.isNaN(factor)) {
    return 'rgb(128,128,128)';
  }

  // Key colors of the scale, evenly spaced over [0, 1]
  const colors = [
    [253, 231, 37], // Yellow
    [94, 201, 98], // Green
    [32, 144, 141], // Teal
    [59, 82, 139], // Blue
    [68, 1, 84], // Purple
  ];

  // Locate the segment containing `factor`; factor === 1 is folded into the
  // last segment so that `segment + 1` never runs off the end of the array.
  const numSegments = colors.length - 1;
  const position = factor * numSegments;
  const segment = Math.min(Math.floor(position), numSegments - 1);
  const segmentFactor = position - segment;

  // Interpolate each channel between the two colors bounding the segment
  const from = colors[segment];
  const to = colors[segment + 1];
  const [red, green, blue] = from.map((channel, k) =>
    Math.round(channel + (to[k] - channel) * segmentFactor)
  );
  return `rgb(${red},${green},${blue})`;
};
// Plot-area size in pixels; the SVG adds `margin` around it for axes and legend.
const width = 600;
const height = 450;
const margin = { top: 50, right: 100, bottom: 60, left: 60 };
const cells = 200; // Grid resolution per axis; high for a smooth gradient
const cellWidth = width / cells;
const cellHeight = height / cells;

// Flat list with one entry per heatmap cell (cells * cells entries), each
// carrying its top-left pixel position, the (confidence, accuracy) it samples,
// the resulting Brier score and the fill color.
// NOTE(review): this is rebuilt on every render, including each hover update;
// consider memoizing if rendering becomes sluggish.
const cells_array = Array.from({ length: cells }, (_, i) => {
  // Column i samples confidence evenly from 0.5 to 1.0 inclusive
  const confidence = 0.5 + (i / (cells - 1)) * 0.5;
  return Array.from({ length: cells }, (_, j) => {
    // Row j samples accuracy from 1.0 (top row) down to 0.0 (bottom row)
    const accuracy = 1 - j / (cells - 1);
    const brierScore = calculateBrierScore(confidence, accuracy);
    return {
      // Columns sit on the same cell grid as rows. Deriving x from the
      // confidence value instead would start the last column at x === width,
      // drawing it entirely outside the plot area.
      x: i * cellWidth,
      y: j * cellHeight,
      confidence,
      accuracy,
      brierScore,
      color: getColor(brierScore),
    };
  });
}).flat();
/**
 * Mouse-move handler for the SVG: publishes the confidence, accuracy and
 * Brier score under the pointer via `setHoverInfo`, or clears the hover
 * state when the pointer is outside the plot area (e.g. over the margins).
 *
 * @param {object} event - Mouse event; `currentTarget` is the element the
 *   handler is attached to, and `clientX`/`clientY` are viewport coordinates.
 */
const handleMouseMove = (event) => {
  const svgRect = event.currentTarget.getBoundingClientRect();

  // Pointer position as a fraction of the plot area. The vertical fraction
  // is flipped so that 0 is the bottom edge and 1 is the top edge.
  const x = (event.clientX - svgRect.left - margin.left) / width;
  const y = 1 - (event.clientY - svgRect.top - margin.top) / height;

  const isInsidePlot = x >= 0 && x <= 1 && y >= 0 && y <= 1;
  if (!isInsidePlot) {
    setHoverInfo(null);
    return;
  }

  // Horizontal axis spans confidence 0.5..1.0; vertical axis is accuracy 0..1
  const confidence = 0.5 + x * 0.5;
  const accuracy = y;
  const brierScore = calculateBrierScore(confidence, accuracy);
  setHoverInfo({ confidence, accuracy, brierScore });
};
// Mouse-leave handler for the SVG: clears the hover state so the hover
// read-out is no longer rendered.
const handleMouseLeave = () => {
  setHoverInfo(null);
};
return ( | |
<div className="w-full max-w-4xl"> | |
<h2 className="text-xl font-bold mb-4">Brier Score Heatmap</h2> | |
<div className="relative"> | |
<svg | |
width={width + margin.left + margin.right} | |
height={height + margin.top + margin.bottom} | |
onMouseMove={handleMouseMove} | |
onMouseLeave={handleMouseLeave} | |
> | |
<g transform={`translate(${margin.left},${margin.top})`}> | |
{/* Heatmap cells */} | |
{cells_array.map((cell, i) => ( | |
<rect | |
key={i} | |
x={cell.x} | |
y={cell.y} | |
width={cellWidth + 0.5} | |
height={cellHeight + 0.5} | |
fill={cell.color} | |
/> | |
))} | |
{/* Axes */} | |
<line x1="0" y1={height} x2={width} y2={height} stroke="black" strokeWidth="2"/> | |
<line x1="0" y1="0" x2="0" y2={height} stroke="black" strokeWidth="2"/> | |
{/* X-axis ticks and labels */} | |
{[0.5, 0.625, 0.75, 0.875, 1].map(tick => ( | |
<g key={tick} transform={`translate(${(tick - 0.5) * width * 2},${height})`}> | |
<line y2="6" stroke="black"/> | |
<text y="20" textAnchor="middle">{tick.toFixed(3)}</text> | |
</g> | |
))} | |
{/* Y-axis ticks and labels */} | |
{[0, 0.25, 0.5, 0.75, 1].map(tick => ( | |
<g key={tick} transform={`translate(0,${height - tick * height})`}> | |
<line x2="-6" stroke="black"/> | |
<text x="-10" dy="0.32em" textAnchor="end">{tick.toFixed(2)}</text> | |
</g> | |
))} | |
{/* Axis labels */} | |
<text | |
x={width/2} | |
y={height + 40} | |
textAnchor="middle" | |
>Prediction Confidence</text> | |
<text | |
transform={`translate(-40,${height/2}) rotate(-90)`} | |
textAnchor="middle" | |
>Actual Accuracy</text> | |
{/* Color scale legend */} | |
<defs> | |
<linearGradient id="legendGradient" x1="0" x2="0" y1="0" y2="1"> | |
<stop offset="0%" stopColor="rgb(255,255,0)"/> | |
<stop offset="100%" stopColor="rgb(0,0,255)"/> | |
</linearGradient> | |
</defs> | |
<rect | |
x={width + 40} | |
y={0} | |
width="20" | |
height={height} | |
fill="url(#legendGradient)" | |
/> | |
<text | |
x={width + 70} | |
y={0} | |
dy="0.32em" | |
textAnchor="start" | |
>0.0</text> | |
<text | |
x={width + 70} | |
y={height} | |
dy="0.32em" | |
textAnchor="start" | |
>1.0</text> | |
<text | |
transform={`translate(${width + 90},${height/2}) rotate(-90)`} | |
textAnchor="middle" | |
>Brier Score</text> | |
</g> | |
</svg> | |
{/* Hover information */} | |
{hoverInfo && ( | |
<div className="absolute bg-white p-2 rounded shadow border border-gray-200"> | |
<p className="text-sm">Confidence: {hoverInfo.confidence.toFixed(3)}</p> | |
<p className="text-sm">Accuracy: {hoverInfo.accuracy.toFixed(3)}</p> | |
<p className="text-sm">Brier Score: {hoverInfo.brierScore.toFixed(3)}</p> | |
</div> | |
)} | |
</div> | |
<div className="mt-4 text-sm text-gray-600"> | |
<p>Lower Brier scores (yellow) indicate better calibrated predictions</p> | |
<p>Higher Brier scores (blue) indicate poorly calibrated predictions</p> | |
</div> | |
</div> | |
); | |
}; | |
export default BrierScoreHeatmap; |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment