about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock195
-rw-r--r--Cargo.toml2
-rw-r--r--src/lib.rs1
-rw-r--r--src/nih_kmeans.rs129
-rw-r--r--src/selection.rs31
5 files changed, 145 insertions, 213 deletions
diff --git a/Cargo.lock b/Cargo.lock
index ab2bacc..7520455 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -15,12 +15,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
 
 [[package]]
-name = "autocfg"
-version = "1.1.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
-
-[[package]]
 name = "bitflags"
 version = "1.3.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -67,7 +61,7 @@ name = "colorsquash"
 version = "0.1.0"
 dependencies = [
  "gifed",
- "kmeans",
+ "rand",
  "rgb",
 ]
 
@@ -81,37 +75,6 @@ dependencies = [
 ]
 
 [[package]]
-name = "crossbeam-deque"
-version = "0.8.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
-dependencies = [
- "crossbeam-epoch",
- "crossbeam-utils",
-]
-
-[[package]]
-name = "crossbeam-epoch"
-version = "0.9.18"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
-dependencies = [
- "crossbeam-utils",
-]
-
-[[package]]
-name = "crossbeam-utils"
-version = "0.8.19"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
-
-[[package]]
-name = "either"
-version = "1.9.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
-
-[[package]]
 name = "fdeflate"
 version = "0.3.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -138,9 +101,9 @@ checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c"
 
 [[package]]
 name = "getrandom"
-version = "0.1.16"
+version = "0.2.12"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8fc3cb4d91f53b50155bdcfd23f6a4c39ae1969c2ae85982b135750cccaf5fce"
+checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5"
 dependencies = [
  "cfg-if",
  "libc",
@@ -156,30 +119,12 @@ dependencies = [
 ]
 
 [[package]]
-name = "kmeans"
-version = "0.2.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "76ccc6d18ad4bdf1b31a515991e73192cc1ef9e0ff06ea8ade4d95f80ee70352"
-dependencies = [
- "num",
- "packed_simd",
- "rand",
- "rayon",
-]
-
-[[package]]
 name = "libc"
 version = "0.2.152"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7"
 
 [[package]]
-name = "libm"
-version = "0.2.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058"
-
-[[package]]
 name = "log"
 version = "0.4.20"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -196,93 +141,6 @@ dependencies = [
 ]
 
 [[package]]
-name = "num"
-version = "0.3.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8b7a8e9be5e039e2ff869df49155f1c06bd01ade2117ec783e56ab0932b67a8f"
-dependencies = [
- "num-bigint",
- "num-complex",
- "num-integer",
- "num-iter",
- "num-rational",
- "num-traits",
-]
-
-[[package]]
-name = "num-bigint"
-version = "0.3.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5f6f7833f2cbf2360a6cfd58cd41a53aa7a90bd4c202f5b1c7dd2ed73c57b2c3"
-dependencies = [
- "autocfg",
- "num-integer",
- "num-traits",
-]
-
-[[package]]
-name = "num-complex"
-version = "0.3.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "747d632c0c558b87dbabbe6a82f3b4ae03720d0646ac5b7b4dae89394be5f2c5"
-dependencies = [
- "num-traits",
-]
-
-[[package]]
-name = "num-integer"
-version = "0.1.45"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
-dependencies = [
- "autocfg",
- "num-traits",
-]
-
-[[package]]
-name = "num-iter"
-version = "0.1.43"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
-dependencies = [
- "autocfg",
- "num-integer",
- "num-traits",
-]
-
-[[package]]
-name = "num-rational"
-version = "0.3.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "12ac428b1cb17fce6f731001d307d351ec70a6d202fc2e60f7d4c5e42d8f4f07"
-dependencies = [
- "autocfg",
- "num-bigint",
- "num-integer",
- "num-traits",
-]
-
-[[package]]
-name = "num-traits"
-version = "0.2.17"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
-dependencies = [
- "autocfg",
- "libm",
-]
-
-[[package]]
-name = "packed_simd"
-version = "0.3.9"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1f9f08af0c877571712e2e3e686ad79efad9657dbf0f7c3c8ba943ff6c38932d"
-dependencies = [
- "cfg-if",
- "num-traits",
-]
-
-[[package]]
 name = "png"
 version = "0.17.10"
 source = "git+https://github.com/image-rs/image-png.git?rev=f10238a1e886b228e7da5301e5c0f5011316f2d6#f10238a1e886b228e7da5301e5c0f5011316f2d6"
@@ -308,22 +166,20 @@ checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09"
 
 [[package]]
 name = "rand"
-version = "0.7.3"
+version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6a6b1679d49b24bbfe0c803429aa1874472f50d9b363131f0e89fc356b544d03"
+checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
 dependencies = [
- "getrandom",
  "libc",
  "rand_chacha",
  "rand_core",
- "rand_hc",
 ]
 
 [[package]]
 name = "rand_chacha"
-version = "0.2.2"
+version = "0.3.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f4c8ed856279c9737206bf725bf36935d8666ead7aa69b52be55af369d193402"
+checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
 dependencies = [
  "ppv-lite86",
  "rand_core",
@@ -331,43 +187,14 @@ dependencies = [
 
 [[package]]
 name = "rand_core"
-version = "0.5.1"
+version = "0.6.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "90bde5296fc891b0cef12a6d03ddccc162ce7b2aff54160af9338f8d40df6d19"
+checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
 dependencies = [
  "getrandom",
 ]
 
 [[package]]
-name = "rand_hc"
-version = "0.2.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ca3129af7b92a17112d59ad498c6f81eaf463253766b90396d39ea7a39d6613c"
-dependencies = [
- "rand_core",
-]
-
-[[package]]
-name = "rayon"
-version = "1.8.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1"
-dependencies = [
- "either",
- "rayon-core",
-]
-
-[[package]]
-name = "rayon-core"
-version = "1.12.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed"
-dependencies = [
- "crossbeam-deque",
- "crossbeam-utils",
-]
-
-[[package]]
 name = "rgb"
 version = "0.8.37"
 source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -402,9 +229,9 @@ checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369"
 
 [[package]]
 name = "wasi"
-version = "0.9.0+wasi-snapshot-preview1"
+version = "0.11.0+wasi-snapshot-preview1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "cccddf32554fecc6acb585f82a32a72e28b48f8c4c1883ddfeeeaa96f7d8e519"
+checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
 
 [[package]]
 name = "weezl"
diff --git a/Cargo.toml b/Cargo.toml
index a877168..e0cdfc4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@ repository = "https://github.com/gennyble/colorsquash"
 [dependencies]
 rgb = "0.8.36"
 gifed = { path = "../gifed/gifed", optional = true }
-kmeans = { version = "0.2.1", optional = true }
+rand = { version = "0.8.5", optional = true }
 
 [workspace]
 members = ["squash"]
diff --git a/src/lib.rs b/src/lib.rs
index 31a4641..9ab6b5c 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@ use std::collections::HashSet;
 use rgb::{ComponentBytes, FromSlice, RGB8};
 
 pub mod difference;
+mod nih_kmeans;
 pub mod selection;
 
 use difference::DiffFn;
diff --git a/src/nih_kmeans.rs b/src/nih_kmeans.rs
new file mode 100644
index 0000000..a34d528
--- /dev/null
+++ b/src/nih_kmeans.rs
@@ -0,0 +1,129 @@
+use std::collections::HashMap;
+
+#[cfg(rand)]
+use rand::{prelude::*, seq::index::sample};
+use rgb::{RGB, RGB8};
+
+pub struct KMeans {
+	samples: Vec<RGB8>,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
+struct HashableRGBF {
+	inner: (u32, u32, u32),
+}
+
+impl From<RGB<f32>> for HashableRGBF {
+	fn from(value: RGB<f32>) -> Self {
+		Self {
+			inner: (value.r.to_bits(), value.g.to_bits(), value.b.to_bits()),
+		}
+	}
+}
+
+impl KMeans {
+	pub fn new(samples: Vec<RGB8>) -> Self {
+		Self { samples }
+	}
+	pub fn get_k_colors(&self, k: usize, max_iter: usize) -> Vec<RGB8> {
+		let mut centroids = self.get_centroid_seeds_simple(k);
+
+		for _ in 0..max_iter {
+			let mut clusters: HashMap<HashableRGBF, Vec<RGB8>> = HashMap::new();
+
+			for &sample in &self.samples {
+				let closest_centroid = Self::closest_centroid(&centroids, sample.into());
+				clusters
+					.entry(closest_centroid.into())
+					.or_default()
+					.push(sample);
+			}
+			centroids = clusters
+				.into_values()
+				.map(|members| vector_avg(&members))
+				.collect()
+		}
+		centroids
+			.into_iter()
+			.map(|c| RGB8::new(c.r.round() as u8, c.g.round() as u8, c.b.round() as u8))
+			.collect()
+	}
+
+	/// Picks a point at random (if feature rand is enabled) for the first centroid, then iteratively adds the point furthest away from any centroid
+	/// A more complex solution is the probabilistic k-means++ algorithm (https://www.mathworks.com/help/stats/kmeans.html#bueq7aj-5)
+	fn get_centroid_seeds_simple(&self, k: usize) -> Vec<RGB<f32>> {
+		if k >= self.samples.len() {
+			return self.samples.iter().map(|&v| v.into()).collect();
+		}
+
+		#[cfg(rand)]
+		let index = thread_rng().gen_range(0..self.samples.len());
+		#[cfg(not(rand))]
+		let index = 0; //lol
+
+		let mut centroids: Vec<RGB<f32>> = vec![self.samples[index].into()];
+		while centroids.len() < k {
+			let next = *self
+				.samples
+				.iter()
+				.max_by(|&&v1, &&v2| {
+					let v1_closest_centroid = Self::closest_centroid(&centroids, v1.into());
+					let v2_closest_centroid = Self::closest_centroid(&centroids, v2.into());
+
+					vector_diff_2_norm(v1.into(), v1_closest_centroid)
+						.partial_cmp(&vector_diff_2_norm(v2.into(), v2_closest_centroid))
+						.unwrap()
+				})
+				.unwrap();
+			centroids.push(next.into());
+		}
+		centroids
+	}
+
+	fn closest_centroid(centroids: &[RGB<f32>], v: RGB<f32>) -> RGB<f32> {
+		*centroids
+			.iter()
+			.min_by(|&&c1, &&c2| {
+				vector_diff_2_norm(c1, v)
+					.partial_cmp(&vector_diff_2_norm(c2, v))
+					.unwrap()
+			})
+			.unwrap()
+	}
+
+	#[cfg(rand)]
+	fn get_centroid_seeds_random(&self, k: usize) -> Vec<RGB<f32>> {
+		if k >= self.samples.len() {
+			return self.samples.iter().map(|&v| v.into()).collect();
+		}
+
+		sample(&mut thread_rng(), self.samples.len(), k)
+			.into_iter()
+			.map(|i| self.samples[i].into())
+			.collect()
+	}
+}
+
+fn vector_diff(v1: RGB<f32>, v2: RGB<f32>) -> RGB<f32> {
+	RGB::new(v1.r - v2.r, v1.g - v2.g, v1.b - v2.b)
+}
+
+fn vector_diff_2_norm(v1: RGB<f32>, v2: RGB<f32>) -> f32 {
+	let diff = vector_diff(v1, v2);
+	(diff.r.powi(2) + diff.g.powi(2) + diff.b.powi(2)).sqrt()
+}
+
+fn vector_sum(acc: RGB<f32>, elem: RGB<f32>) -> RGB<f32> {
+	RGB::new(acc.r + elem.r, acc.g + elem.g, acc.b + elem.b)
+}
+
+fn vector_avg(vs: &[RGB8]) -> RGB<f32> {
+	let summed = vs.iter().fold(RGB::new(0.0, 0.0, 0.0), |acc, elem| {
+		vector_sum(acc, (*elem).into())
+	});
+	RGB::new(
+		summed.r / vs.len() as f32,
+		summed.g / vs.len() as f32,
+		summed.b / vs.len() as f32,
+	)
+}
diff --git a/src/selection.rs b/src/selection.rs
index 1e27ac4..dacd735 100644
--- a/src/selection.rs
+++ b/src/selection.rs
@@ -1,7 +1,7 @@
 use std::collections::HashMap;
 
 #[cfg(feature = "kmeans")]
-use kmeans::{KMeans, KMeansConfig};
+use crate::nih_kmeans::KMeans;
 use rgb::{ComponentBytes, RGB8};
 
 use crate::{
@@ -112,32 +112,7 @@ impl Selector for Kmeans {
 	fn select(&mut self, max_colors: usize, image: ImageData) -> Vec<RGB8> {
 		let ImageData(rgb) = image;
 
-		let kmean = KMeans::new(
-			rgb.as_bytes()
-				.iter()
-				.map(|u| *u as f32)
-				.collect::<Vec<f32>>(),
-			rgb.as_bytes().len() / 3,
-			3,
-		);
-
-		let result = kmean.kmeans_lloyd(
-			max_colors,
-			100,
-			KMeans::init_kmeanplusplus,
-			&KMeansConfig::default(),
-		);
-
-		result
-			.centroids
-			.chunks_exact(3)
-			.map(|rgb| {
-				RGB8::new(
-					rgb[0].round() as u8,
-					rgb[1].round() as u8,
-					rgb[2].round() as u8,
-				)
-			})
-			.collect()
+		let kmean = KMeans::new(rgb.to_vec());
+		kmean.get_k_colors(max_colors, max_iter)
 	}
 }