about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorgennyble <gen@nyble.dev>2024-01-16 15:06:29 -0600
committergennyble <gen@nyble.dev>2024-01-16 15:06:29 -0600
commitc887f5fef803e720052898b9561028dfd50e51db (patch)
tree0ea9c8c9e6c4603d0626bc12e653b6421f16f78a /src
parentbbbac45d835dd40c3fb53f8c7f9a2731783841e8 (diff)
downloadcolorsquash-c887f5fef803e720052898b9561028dfd50e51db.tar.gz
colorsquash-c887f5fef803e720052898b9561028dfd50e51db.zip
ability to choose kmeans implementation
Diffstat (limited to 'src')
-rw-r--r--src/lib.rs1
-rw-r--r--src/nih_kmeans.rs2
-rw-r--r--src/selection.rs52
3 files changed, 48 insertions, 7 deletions
diff --git a/src/lib.rs b/src/lib.rs
index 9ab6b5c..cbecf76 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,6 +3,7 @@ use std::collections::HashSet;
 use rgb::{ComponentBytes, FromSlice, RGB8};
 
 pub mod difference;
+#[cfg(not(feature = "simd-kmeans"))]
 mod nih_kmeans;
 pub mod selection;
 
diff --git a/src/nih_kmeans.rs b/src/nih_kmeans.rs
index a34d528..ee752bc 100644
--- a/src/nih_kmeans.rs
+++ b/src/nih_kmeans.rs
@@ -1,6 +1,6 @@
 use std::collections::HashMap;
 
-#[cfg(rand)]
+#[cfg(feature = "rand")]
 use rand::{prelude::*, seq::index::sample};
 use rgb::{RGB, RGB8};
 
diff --git a/src/selection.rs b/src/selection.rs
index dacd735..d93603b 100644
--- a/src/selection.rs
+++ b/src/selection.rs
@@ -1,8 +1,10 @@
 use std::collections::HashMap;
 
-#[cfg(feature = "kmeans")]
+#[cfg(not(feature = "simd-kmeans"))]
 use crate::nih_kmeans::KMeans;
-use rgb::{ComponentBytes, RGB8};
+#[cfg(feature = "simd-kmeans")]
+use kmeans::{KMeans, KMeansConfig};
+use rgb::RGB8;
 
 use crate::{
 	difference::{self, DiffFn},
@@ -103,16 +105,54 @@ impl Default for SortSelect {
 	}
 }
 
-#[cfg(feature = "kmeans")]
 #[derive(Debug, Default)]
-pub struct Kmeans;
+pub struct Kmeans {
+	pub max_iter: usize,
+}
 
-#[cfg(feature = "kmeans")]
+#[cfg(not(feature = "simd-kmeans"))]
 impl Selector for Kmeans {
 	fn select(&mut self, max_colors: usize, image: ImageData) -> Vec<RGB8> {
 		let ImageData(rgb) = image;
 
 		let kmean = KMeans::new(rgb.to_vec());
-		kmean.get_k_colors(max_colors, max_iter)
+		kmean.get_k_colors(max_colors, self.max_iter)
+	}
+}
+
+#[cfg(feature = "simd-kmeans")]
+impl Selector for Kmeans {
+	fn select(&mut self, max_colors: usize, image: ImageData) -> Vec<RGB8> {
+		use rgb::ComponentBytes;
+
+		let ImageData(rgb) = image;
+
+		let kmean = KMeans::new(
+			rgb.as_bytes()
+				.iter()
+				.map(|u| *u as f32)
+				.collect::<Vec<f32>>(),
+			rgb.as_bytes().len() / 3,
+			3,
+		);
+
+		let result = kmean.kmeans_lloyd(
+			max_colors,
+			self.max_iter,
+			KMeans::init_kmeanplusplus,
+			&KMeansConfig::default(),
+		);
+
+		result
+			.centroids
+			.chunks_exact(3)
+			.map(|rgb| {
+				RGB8::new(
+					rgb[0].round() as u8,
+					rgb[1].round() as u8,
+					rgb[2].round() as u8,
+				)
+			})
+			.collect()
 	}
 }