Created
November 15, 2017 06:54
-
-
Save joeld42/f872d393ae7d2b35c4826dac349984b9 to your computer and use it in GitHub Desktop.
Simple brute-force k-means palette quantization
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Slow, brute-force (unoptimized) K-Means image palette quantization.
# This code is in the public domain.
import os, sys | |
import random | |
from PIL import Image, ImageDraw | |
# The result doesn't seem to get much better beyond this many iterations, YMMV...
MAX_ITER = 10 | |
NUM_MEANS = 16 | |
class ColorGroup:
    """One k-means cluster: a running color sum plus the current mean color."""

    def __init__(self):
        # Accumulators, refilled on every iteration by the caller.
        self.totColor = (0.0, 0.0, 0.0)
        self.count = 0
        # Current cluster center as a 3-component color.
        self.meanColor = (0.0, 0.0, 0.0)
        # Count seen on the previous iteration; the caller compares it with
        # `count` to decide whether the clustering has settled.
        # NOTE: This is potentially wrong but seems to work OK
        self.lastCount = 0

    def roundColor(self):
        """Return the mean color with its first three channels rounded to ints."""
        return tuple(int(round(self.meanColor[axis])) for axis in range(3))
def closestColor(colors, target):
    """Return the entry of `colors` whose meanColor is nearest to `target`.

    Distance is the squared Euclidean distance over the first three
    components. On a tie the earliest entry wins. Returns None when
    `colors` yields nothing.
    """
    winner = None
    winner_dist = 0.0  # only meaningful once `winner` has been set
    for candidate in colors:
        mean = candidate.meanColor
        d0 = target[0] - mean[0]
        d1 = target[1] - mean[1]
        d2 = target[2] - mean[2]
        dist = d0 ** 2 + d1 ** 2 + d2 ** 2
        # Skip anything that does not strictly beat the current best.
        if winner is not None and not dist < winner_dist:
            continue
        winner, winner_dist = candidate, dist
    return winner
if __name__ == '__main__':
    # Script entry point: quantize the image named on the command line to at
    # most NUM_MEANS colors using weighted k-means over its unique colors,
    # then save the recolored image (with a palette strip drawn along the
    # top) next to the input as "<name>_pal<NUM_MEANS><ext>".

    if len(sys.argv) < 2:
        # Single-argument print(...) behaves the same on Python 2 and 3.
        print("Usage: k-means <input image>")
        sys.exit(1)

    infile = sys.argv[1]
    outbase, outext = os.path.splitext(infile)
    outfilename = outbase + "_pal" + str(NUM_MEANS) + outext

    img = Image.open(infile)
    # The color math below indexes pixels as (r, g, b). Modes other than
    # RGB/RGBA (e.g. palette or grayscale) are converted so that indexing
    # is valid instead of failing on a non-tuple pixel value.
    if img.mode not in ("RGB", "RGBA"):
        img = img.convert("RGB")
    pix = img.load()
    w, h = img.size

    # Histogram of the distinct colors; k-means then runs over unique colors
    # weighted by pixel count rather than over every pixel.
    unique_color_counts = {}
    for j in range(h):
        for i in range(w):
            c = pix[i, j]
            unique_color_counts[c] = unique_color_counts.get(c, 0) + 1

    unique_colors = list(unique_color_counts.keys())
    random.shuffle(unique_colors)
    print("%d  colors in image" % len(unique_colors))

    # Seed each mean with a distinct random color from the image. Slicing
    # clamps the seed count, so an image with fewer than NUM_MEANS unique
    # colors no longer raises IndexError.
    means = []
    for startCol in unique_colors[:NUM_MEANS]:
        cg = ColorGroup()
        cg.meanColor = startCol
        print(startCol)
        means.append(cg)

    for step in range(MAX_ITER):
        print("K-Means iter  %d" % step)

        # Reset the accumulators for this pass.
        for cg in means:
            cg.totColor = (0.0, 0.0, 0.0)
            cg.count = 0

        # Assignment step: add every unique color, weighted by how many
        # pixels have it, to its nearest mean.
        for c in unique_colors:
            group = closestColor(means, c)
            weight = unique_color_counts[c]
            group.totColor = (group.totColor[0] + c[0] * weight,
                              group.totColor[1] + c[1] * weight,
                              group.totColor[2] + c[2] * weight)
            group.count += weight

        # Update step: move each non-empty mean to the weighted average.
        # Convergence is judged by the per-group pixel counts being the same
        # as on the previous pass. Equal counts do not strictly prove that
        # the assignments are unchanged, but this has worked in practice.
        converged = True
        for cg in means:
            if cg.count > 0:
                # totColor holds floats, so this is float division on both
                # Python 2 and Python 3.
                cg.totColor = (cg.totColor[0] / cg.count,
                               cg.totColor[1] / cg.count,
                               cg.totColor[2] / cg.count)
                cg.meanColor = cg.totColor

            if cg.lastCount != cg.count:
                cg.lastCount = cg.count
                converged = False

        if converged:
            print("Converged...")
            break

    # Palettize result (still save a RGB image, but just for visualization
    # of result). The nearest palette entry is resolved once per unique
    # color instead of once per pixel.
    palette_for = {}
    for c in unique_colors:
        palette_for[c] = closestColor(means, c).roundColor()

    # Cover the full image: the previous range(h-1) / range(w-1) bounds left
    # the last row and the last column unquantized.
    for j in range(h):
        for i in range(w):
            pix[i, j] = palette_for[pix[i, j]]

    # Draw the palette as a strip of equal-width swatches across the top.
    # The width is derived from the number of means actually created, which
    # can be below NUM_MEANS for images with few colors.
    if means:
        barSz = float(w) / len(means)
        draw = ImageDraw.Draw(img)
        for i, cg in enumerate(means):
            draw.rectangle([int(i * barSz), 0, int((i + 1) * barSz), 20],
                           fill=cg.roundColor(), outline=(255, 255, 255))

    img.save(outfilename)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment