about summary refs log tree commit diff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs32
1 files changed, 9 insertions, 23 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 40a8f28..ed3d4e2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,9 +1,9 @@
-#[cfg(kmeans)]
-use kmeans::{KMeans, KMeansConfig};
+use nih_kmeans::KMeans;
 use rgb::RGB8;
 use std::collections::HashMap;
 
 pub mod difference;
+mod nih_kmeans;
 
 type DiffFn = dyn Fn(&RGB8, &RGB8) -> f32;
 
@@ -93,34 +93,21 @@ impl<T: Count> Squasher<T> {
 	}
 
 	/// Create a new palette from the colours in the given image.
-	#[cfg(not(kmeans))]
 	pub fn recolor(&mut self, image: &[u8]) {
 		let sorted = Self::unique_and_sort(image);
 		let selected = self.select_colors(sorted);
 		self.palette = selected;
 	}
 
-	#[cfg(kmeans)]
-	pub fn recolor(&mut self, image: &[u8]) {
+	/// Create a new palette from the colours in the given image, using the iterative kmeans algorithm with simplified seeding
+	pub fn recolor_kmeans(&mut self, image: &[u8], max_iter: usize) {
 		let kmean = KMeans::new(
-			image.iter().map(|u| *u as f32).collect::<Vec<f32>>(),
-			image.len() / 3,
-			3,
+			image
+				.chunks_exact(3)
+				.map(|bytes| RGB8::new(bytes[0], bytes[1], bytes[2]))
+				.collect(),
 		);
-		let k = self.max_colours_min1.as_usize() + 1;
-		let result =
-			kmean.kmeans_lloyd(k, 100, KMeans::init_kmeanplusplus, &KMeansConfig::default());
-		self.palette = result
-			.centroids
-			.chunks_exact(3)
-			.map(|rgb| {
-				RGB8::new(
-					rgb[0].round() as u8,
-					rgb[1].round() as u8,
-					rgb[2].round() as u8,
-				)
-			})
-			.collect();
+		self.palette = kmean.get_k_colors(self.max_colours_min1.as_usize() + 1, max_iter);
 	}
 
 	/// Create a Squasher from parts. Noteably, this leave your palette empty
@@ -225,7 +212,6 @@ impl<T: Count> Squasher<T> {
 
 	/// Pick the colors in the palette from a Vec of colors sorted by number
 	/// of times they occur, high to low.
-	#[cfg(not(kmeans))]
 	fn select_colors(&self, sorted: Vec<RGB8>) -> Vec<RGB8> {
 		// I made these numbers up
 		#[allow(non_snake_case)]