From ecb0519646f6255410a58750c9be394fd20b1f03 Mon Sep 17 00:00:00 2001 From: Devon Sawatsky Date: Sun, 14 Jan 2024 19:31:26 -0800 Subject: remove kmeans in favor of NIH'd implementation --- src/lib.rs | 32 +++++++++--------------------- src/nih_kmeans.rs | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 63 insertions(+), 28 deletions(-) (limited to 'src') diff --git a/src/lib.rs b/src/lib.rs index 40a8f28..ed3d4e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,9 @@ -#[cfg(kmeans)] -use kmeans::{KMeans, KMeansConfig}; +use nih_kmeans::KMeans; use rgb::RGB8; use std::collections::HashMap; pub mod difference; +mod nih_kmeans; type DiffFn = dyn Fn(&RGB8, &RGB8) -> f32; @@ -93,34 +93,21 @@ impl Squasher { } /// Create a new palette from the colours in the given image. - #[cfg(not(kmeans))] pub fn recolor(&mut self, image: &[u8]) { let sorted = Self::unique_and_sort(image); let selected = self.select_colors(sorted); self.palette = selected; } - #[cfg(kmeans)] - pub fn recolor(&mut self, image: &[u8]) { + /// Create a new palette from the colours in the given image, using the iterative kmeans algorithm with simplified seeding + pub fn recolor_kmeans(&mut self, image: &[u8], max_iter: usize) { let kmean = KMeans::new( - image.iter().map(|u| *u as f32).collect::>(), - image.len() / 3, - 3, + image + .chunks_exact(3) + .map(|bytes| RGB8::new(bytes[0], bytes[1], bytes[2])) + .collect(), ); - let k = self.max_colours_min1.as_usize() + 1; - let result = - kmean.kmeans_lloyd(k, 100, KMeans::init_kmeanplusplus, &KMeansConfig::default()); - self.palette = result - .centroids - .chunks_exact(3) - .map(|rgb| { - RGB8::new( - rgb[0].round() as u8, - rgb[1].round() as u8, - rgb[2].round() as u8, - ) - }) - .collect(); + self.palette = kmean.get_k_colors(self.max_colours_min1.as_usize() + 1, max_iter); } /// Create a Squasher from parts. 
Noteably, this leave your palette empty @@ -225,7 +212,6 @@ impl Squasher { /// Pick the colors in the palette from a Vec of colors sorted by number /// of times they occur, high to low. - #[cfg(not(kmeans))] fn select_colors(&self, sorted: Vec) -> Vec { // I made these numbers up #[allow(non_snake_case)] diff --git a/src/nih_kmeans.rs b/src/nih_kmeans.rs index 98a9965..a34d528 100644 --- a/src/nih_kmeans.rs +++ b/src/nih_kmeans.rs @@ -1,3 +1,6 @@ +use std::collections::HashMap; + +#[cfg(rand)] use rand::{prelude::*, seq::index::sample}; use rgb::{RGB, RGB8}; @@ -5,14 +8,40 @@ pub struct KMeans { samples: Vec, } +#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +struct HashableRGBF { + inner: (u32, u32, u32), +} + +impl From> for HashableRGBF { + fn from(value: RGB) -> Self { + Self { + inner: (value.r.to_bits(), value.g.to_bits(), value.b.to_bits()), + } + } +} + impl KMeans { pub fn new(samples: Vec) -> Self { Self { samples } } pub fn get_k_colors(&self, k: usize, max_iter: usize) -> Vec { let mut centroids = self.get_centroid_seeds_simple(k); + for _ in 0..max_iter { - todo!() + let mut clusters: HashMap> = HashMap::new(); + + for &sample in &self.samples { + let closest_centroid = Self::closest_centroid(&centroids, sample.into()); + clusters + .entry(closest_centroid.into()) + .or_default() + .push(sample); + } + centroids = clusters + .into_values() + .map(|members| vector_avg(&members)) + .collect() } centroids .into_iter() @@ -20,15 +49,19 @@ impl KMeans { .collect() } - /// Uses k-means++ algorithm (https://www.mathworks.com/help/stats/kmeans.html#bueq7aj-5) + /// Picks a point at random (if feature rand is enabled) for the first centroid, then iteratively adds the point furthest away from any centroid + /// A more complex solution is the probabilistic k-means++ algorithm (https://www.mathworks.com/help/stats/kmeans.html#bueq7aj-5) fn get_centroid_seeds_simple(&self, k: usize) -> Vec> { if k >= self.samples.len() { return
self.samples.iter().map(|&v| v.into()).collect(); } - let mut rng = thread_rng(); - let mut centroids: Vec> = - vec![self.samples[rng.gen_range(0..self.samples.len())].into()]; + #[cfg(rand)] + let index = thread_rng().gen_range(0..self.samples.len()); + #[cfg(not(rand))] + let index = 0; //lol + + let mut centroids: Vec> = vec![self.samples[index].into()]; while centroids.len() < k { let next = *self .samples @@ -58,6 +91,7 @@ impl KMeans { .unwrap() } + #[cfg(rand)] fn get_centroid_seeds_random(&self, k: usize) -> Vec> { if k >= self.samples.len() { return self.samples.iter().map(|&v| v.into()).collect(); @@ -78,3 +112,18 @@ fn vector_diff_2_norm(v1: RGB, v2: RGB) -> f32 { let diff = vector_diff(v1, v2); (diff.r.powi(2) + diff.g.powi(2) + diff.b.powi(2)).sqrt() } + +fn vector_sum(acc: RGB, elem: RGB) -> RGB { + RGB::new(acc.r + elem.r, acc.g + elem.g, acc.b + elem.b) +} + +fn vector_avg(vs: &[RGB8]) -> RGB { + let summed = vs.iter().fold(RGB::new(0.0, 0.0, 0.0), |acc, elem| { + vector_sum(acc, (*elem).into()) + }); + RGB::new( + summed.r / vs.len() as f32, + summed.g / vs.len() as f32, + summed.b / vs.len() as f32, + ) +} -- cgit 1.4.1-3-g733a5