about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--Cargo.toml2
-rw-r--r--src/lib.rs27
2 files changed, 28 insertions, 1 deletions
diff --git a/Cargo.toml b/Cargo.toml
index f8a9b05..34b1d85 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,7 +12,7 @@ repository = "https://github.com/gennyble/colorsquash"
 [dependencies]
 rgb = "0.8.36"
 gifed = { path = "../gifed/gifed", optional = true }
-kmeans = "0.2.1"
+kmeans = { version = "0.2.1", optional = true }
 
 [workspace]
 members = ["squash"]
diff --git a/src/lib.rs b/src/lib.rs
index 213adfc..40a8f28 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,3 +1,5 @@
+#[cfg(kmeans)]
+use kmeans::{KMeans, KMeansConfig};
 use rgb::RGB8;
 use std::collections::HashMap;
 
@@ -91,12 +93,36 @@ impl<T: Count> Squasher<T> {
 	}
 
 	/// Create a new palette from the colours in the given image.
+	#[cfg(not(kmeans))]
 	pub fn recolor(&mut self, image: &[u8]) {
 		let sorted = Self::unique_and_sort(image);
 		let selected = self.select_colors(sorted);
 		self.palette = selected;
 	}
 
+	#[cfg(kmeans)]
+	pub fn recolor(&mut self, image: &[u8]) {
+		let kmean = KMeans::new(
+			image.iter().map(|u| *u as f32).collect::<Vec<f32>>(),
+			image.len() / 3,
+			3,
+		);
+		let k = self.max_colours_min1.as_usize() + 1;
+		let result =
+			kmean.kmeans_lloyd(k, 100, KMeans::init_kmeanplusplus, &KMeansConfig::default());
+		self.palette = result
+			.centroids
+			.chunks_exact(3)
+			.map(|rgb| {
+				RGB8::new(
+					rgb[0].round() as u8,
+					rgb[1].round() as u8,
+					rgb[2].round() as u8,
+				)
+			})
+			.collect();
+	}
+
 	/// Create a Squasher from parts. Noteably, this leave your palette empty
 	fn from_parts(max_colours_min1: T, difference_fn: Box<DiffFn>, tolerance: f32) -> Self {
 		Self {
@@ -199,6 +225,7 @@ impl<T: Count> Squasher<T> {
 
 	/// Pick the colors in the palette from a Vec of colors sorted by number
 	/// of times they occur, high to low.
+	#[cfg(not(kmeans))]
 	fn select_colors(&self, sorted: Vec<RGB8>) -> Vec<RGB8> {
 		// I made these numbers up
 		#[allow(non_snake_case)]