about summary refs log tree commit diff
path: root/src/selection.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/selection.rs')
-rw-r--r--src/selection.rs52
1 files changed, 46 insertions, 6 deletions
diff --git a/src/selection.rs b/src/selection.rs
index dacd735..d93603b 100644
--- a/src/selection.rs
+++ b/src/selection.rs
@@ -1,8 +1,10 @@
 use std::collections::HashMap;
 
-#[cfg(feature = "kmeans")]
+#[cfg(not(feature = "simd-kmeans"))]
 use crate::nih_kmeans::KMeans;
-use rgb::{ComponentBytes, RGB8};
+#[cfg(feature = "simd-kmeans")]
+use kmeans::{KMeans, KMeansConfig};
+use rgb::RGB8;
 
 use crate::{
 	difference::{self, DiffFn},
@@ -103,16 +105,54 @@ impl Default for SortSelect {
 	}
 }
 
-#[cfg(feature = "kmeans")]
 #[derive(Debug, Default)]
-pub struct Kmeans;
+pub struct Kmeans {
+	pub max_iter: usize,
+}
 
-#[cfg(feature = "kmeans")]
+#[cfg(not(feature = "simd-kmeans"))]
 impl Selector for Kmeans {
 	fn select(&mut self, max_colors: usize, image: ImageData) -> Vec<RGB8> {
 		let ImageData(rgb) = image;
 
 		let kmean = KMeans::new(rgb.to_vec());
-		kmean.get_k_colors(max_colors, max_iter)
+		kmean.get_k_colors(max_colors, self.max_iter)
+	}
+}
+
+#[cfg(feature = "simd-kmeans")]
+impl Selector for Kmeans {
+	fn select(&mut self, max_colors: usize, image: ImageData) -> Vec<RGB8> {
+		use rgb::ComponentBytes;
+
+		let ImageData(rgb) = image;
+
+		let kmean = KMeans::new(
+			rgb.as_bytes()
+				.iter()
+				.map(|u| *u as f32)
+				.collect::<Vec<f32>>(),
+			rgb.as_bytes().len() / 3,
+			3,
+		);
+
+		let result = kmean.kmeans_lloyd(
+			max_colors,
+			self.max_iter,
+			KMeans::init_kmeanplusplus,
+			&KMeansConfig::default(),
+		);
+
+		result
+			.centroids
+			.chunks_exact(3)
+			.map(|rgb| {
+				RGB8::new(
+					rgb[0].round() as u8,
+					rgb[1].round() as u8,
+					rgb[2].round() as u8,
+				)
+			})
+			.collect()
 	}
 }