From ecb0519646f6255410a58750c9be394fd20b1f03 Mon Sep 17 00:00:00 2001 From: Devon Sawatsky Date: Sun, 14 Jan 2024 19:31:26 -0800 Subject: remove kmeans in favor of NIH'd implementation --- Cargo.lock | 195 +++--------------------------------------------------- Cargo.toml | 2 +- src/lib.rs | 32 +++------ src/nih_kmeans.rs | 59 +++++++++++++++-- 4 files changed, 75 insertions(+), 213 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ab2bacc..7520455 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,12 +14,6 @@ version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" -[[package]] -name = "autocfg" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" - [[package]] name = "bitflags" version = "1.3.2" @@ -67,7 +61,7 @@ name = "colorsquash" version = "0.1.0" dependencies = [ "gifed", - "kmeans", + "rand", "rgb", ] @@ -80,37 +74,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crossbeam-deque" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" -dependencies = [ - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-utils" -version = "0.8.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" - -[[package]] -name = "either" -version = "1.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" - [[package]] name = "fdeflate" version = "0.3.3" @@ -138,9 +101,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" [[package]] name = "getrandom" -version = "0.1.16" +version = "0.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" dependencies = [ "cfg-if", "libc", @@ -155,30 +118,12 @@ dependencies = [ "weezl", ] -[[package]] -name = "kmeans" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76ccc6d18ad4bdf1b31a515991e73192cc1ef9e0ff06ea8ade4d95f80ee70352" -dependencies = [ - "num", - "packed_simd", - "rand", - "rayon", -] - [[package]] name = "libc" version = "0.2.152" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" -[[package]] -name = "libm" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" - [[package]] name = "log" version = "0.4.20" @@ -195,93 +140,6 @@ dependencies = [ "simd-adler32", ] -[[package]] -name = "num" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b7a8e9be5e039e2ff869df49155f1c06bd01ade2117ec783e56ab0932b67a8f" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - -[[package]] -name = "num-bigint" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6f7833f2cbf2360a6cfd58cd41a53aa7a90bd4c202f5b1c7dd2ed73c57b2c3" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-complex" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "747d632c0c558b87dbabbe6a82f3b4ae03720d0646ac5b7b4dae89394be5f2c5" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" -dependencies = [ - "autocfg", - "num-traits", -] - -[[package]] -name = "num-iter" -version = "0.1.43" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12ac428b1cb17fce6f731001d307d351ec70a6d202fc2e60f7d4c5e42d8f4f07" -dependencies = [ - "autocfg", - "num-bigint", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" -dependencies = [ - "autocfg", - "libm", -] - -[[package]] -name = "packed_simd" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f9f08af0c877571712e2e3e686ad79efad9657dbf0f7c3c8ba943ff6c38932d" -dependencies = [ - "cfg-if", - "num-traits", -] - [[package]] name = "png" version = "0.17.10" @@ -308,22 +166,20 @@ checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" [[package]] name = "rand" -version = "0.7.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "getrandom", "libc", "rand_chacha", "rand_core", - "rand_hc", ] [[package]] name = "rand_chacha" -version = "0.2.2" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", "rand_core", @@ -331,42 +187,13 @@ dependencies = [ [[package]] name = "rand_core" -version = "0.5.1" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ "getrandom", ] -[[package]] -name = "rand_hc" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c" -dependencies = [ - "rand_core", -] - -[[package]] -name = "rayon" -version = "1.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - [[package]] name = "rgb" version = "0.8.37" @@ -402,9 +229,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" [[package]] name = "wasi" -version = "0.9.0+wasi-snapshot-preview1" +version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "weezl" diff --git a/Cargo.toml b/Cargo.toml index 34b1d85..7be3913 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,7 @@ repository = "https://github.com/gennyble/colorsquash" [dependencies] rgb = "0.8.36" gifed = { path = "../gifed/gifed", optional = true } -kmeans = { version = "0.2.1", optional = true } +rand = { version = "0.8.5", optional = true } [workspace] members = ["squash"] diff --git a/src/lib.rs b/src/lib.rs index 40a8f28..ed3d4e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,9 @@ -#[cfg(kmeans)] -use kmeans::{KMeans, KMeansConfig}; +use nih_kmeans::KMeans; use rgb::RGB8; use std::collections::HashMap; pub mod difference; +mod nih_kmeans; type DiffFn = dyn Fn(&RGB8, &RGB8) -> f32; @@ -93,34 +93,21 @@ impl Squasher { } /// Create a new palette from the colours in the given image. - #[cfg(not(kmeans))] pub fn recolor(&mut self, image: &[u8]) { let sorted = Self::unique_and_sort(image); let selected = self.select_colors(sorted); self.palette = selected; } - #[cfg(kmeans)] - pub fn recolor(&mut self, image: &[u8]) { + /// Create a new palette from the colours in the given image, using the iterative kmeans algorithm with simplified seeding + pub fn recolor_kmeans(&mut self, image: &[u8], max_iter: usize) { let kmean = KMeans::new( - image.iter().map(|u| *u as f32).collect::>(), - image.len() / 3, - 3, + image + .chunks_exact(3) + .map(|bytes| RGB8::new(bytes[0], bytes[1], bytes[2])) + .collect(), ); - let k = self.max_colours_min1.as_usize() + 1; - let result = - kmean.kmeans_lloyd(k, 100, KMeans::init_kmeanplusplus, &KMeansConfig::default()); - self.palette = result - .centroids - .chunks_exact(3) - .map(|rgb| { - RGB8::new( - rgb[0].round() as u8, - rgb[1].round() as u8, - rgb[2].round() as u8, - ) - }) - .collect(); + self.palette = kmean.get_k_colors(self.max_colours_min1.as_usize() + 1, max_iter); } /// Create a Squasher from parts. Noteably, this leave your palette empty @@ -225,7 +212,6 @@ impl Squasher { /// Pick the colors in the palette from a Vec of colors sorted by number /// of times they occur, high to low. - #[cfg(not(kmeans))] fn select_colors(&self, sorted: Vec) -> Vec { // I made these numbers up #[allow(non_snake_case)] diff --git a/src/nih_kmeans.rs b/src/nih_kmeans.rs index 98a9965..a34d528 100644 --- a/src/nih_kmeans.rs +++ b/src/nih_kmeans.rs @@ -1,3 +1,6 @@ +use std::collections::HashMap; + +#[cfg(rand)] use rand::{prelude::*, seq::index::sample}; use rgb::{RGB, RGB8}; @@ -5,14 +8,40 @@ pub struct KMeans { samples: Vec, } +#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] +struct HashableRGBF { + inner: (u32, u32, u32), +} + +impl From> for HashableRGBF { + fn from(value: RGB) -> Self { + Self { + inner: (value.r.to_bits(), value.g.to_bits(), value.b.to_bits()), + } + } +} + impl KMeans { pub fn new(samples: Vec) -> Self { Self { samples } } pub fn get_k_colors(&self, k: usize, max_iter: usize) -> Vec { let mut centroids = self.get_centroid_seeds_simple(k); + for _ in 0..max_iter { - todo!() + let mut clusters: HashMap> = HashMap::new(); + + for &sample in &self.samples { + let closest_centroid = Self::closest_centroid(¢roids, sample.into()); + clusters + .entry(closest_centroid.into()) + .or_default() + .push(sample); + } + centroids = clusters + .into_values() + .map(|members| vector_avg(&members)) + .collect() } centroids .into_iter() @@ -20,15 +49,19 @@ impl KMeans { .collect() } - /// Uses k-means++ algorithm (https://www.mathworks.com/help/stats/kmeans.html#bueq7aj-5) + /// Picks a point at random (if feature rand is enabled) for the first centroid, then iteratively adds the point furthest away from any centroid + /// A more complex solution is the probabilistic k-means++ algorithm (https://www.mathworks.com/help/stats/kmeans.html#bueq7aj-5) fn get_centroid_seeds_simple(&self, k: usize) -> Vec> { if k >= self.samples.len() { return self.samples.iter().map(|&v| v.into()).collect(); } - let mut rng = thread_rng(); - let mut centroids: Vec> = - vec![self.samples[rng.gen_range(0..self.samples.len())].into()]; + #[cfg(rand)] + let index = thread_rng().gen_range(0..self.samples.len()); + #[cfg(not(rand))] + let index = 0; //lol + + let mut centroids: Vec> = vec![self.samples[index].into()]; while centroids.len() < k { let next = *self .samples @@ -58,6 +91,7 @@ impl KMeans { .unwrap() } + #[cfg(rand)] fn get_centroid_seeds_random(&self, k: usize) -> Vec> { if k >= self.samples.len() { return self.samples.iter().map(|&v| v.into()).collect(); @@ -78,3 +112,18 @@ fn vector_diff_2_norm(v1: RGB, v2: RGB) -> f32 { let diff = vector_diff(v1, v2); (diff.r.powi(2) + diff.g.powi(2) + diff.b.powi(2)).sqrt() } + +fn vector_sum(acc: RGB, elem: RGB) -> RGB { + RGB::new(acc.r + elem.r, acc.g + elem.g, acc.b + elem.b) +} + +fn vector_avg(vs: &[RGB8]) -> RGB { + let summed = vs.iter().fold(RGB::new(0.0, 0.0, 0.0), |acc, elem| { + vector_sum(acc, (*elem).into()) + }); + RGB::new( + summed.r / vs.len() as f32, + summed.g / vs.len() as f32, + summed.b / vs.len() as f32, + ) +} -- cgit 1.4.1-3-g733a5