// Title: KNN (k-nearest neighbors) notes
// Euclidean distance (n-dimensional): \(d(\mathbf{p}, \mathbf{q}) = \sqrt{\sum_{i=1}^{n} (q_i - p_i)^2}\)
use std::collections::HashMap; // For counting occurrences of labels
use std::error::Error;

use serde::Deserialize;
/// One Iris sample: four numeric traits (n = 4 in the distance formula) plus
/// its species label.
///
/// Rows of the dataset CSV are deserialized into this type by
/// `read_data_from_csv`. NOTE(review): the field names are presumably matched
/// against the CSV header row (the csv crate's default) — confirm against the
/// data file before renaming or reordering any field.
#[derive(Debug, Deserialize)]
struct Iris {
    sepal_length: f64,
    sepal_width: f64,
    petal_length: f64,
    petal_width: f64,
    /// Class label. `main` sets this to an empty string for the sample being
    /// classified, because its species is not yet known.
    species: String,
}
// Read dataset from CSV file
fn read_data_from_csv(file_path: &str) -> Result<Vec<Iris>, Box<dyn Error>> {
let mut rdr = csv::Reader::from_path(file_path)?;
let mut data = Vec::new();
for result in rdr.deserialize() {
let record: Iris = result?;
data.push(record);
}
Ok(data)
}
/// Straight-line (L2) distance between two flowers over their four numeric
/// traits; `species` does not participate.
fn euclidean_distance(e1: &Iris, e2: &Iris) -> f64 {
    let lhs = [e1.sepal_length, e1.sepal_width, e1.petal_length, e1.petal_width];
    let rhs = [e2.sepal_length, e2.sepal_width, e2.petal_length, e2.petal_width];
    // Accumulate the squared per-trait differences in trait order.
    let squared_sum: f64 = lhs
        .iter()
        .zip(rhs.iter())
        .map(|(a, b)| (a - b).powi(2))
        .sum();
    squared_sum.sqrt()
}
/// Classifies `new_entity` by majority vote among its `k` nearest neighbours
/// in `data`, returning the winning species label.
///
/// Neighbours are ranked by ascending `euclidean_distance`. The sort is
/// stable, so entities at equal distance keep their order from `data`. If
/// several labels share the highest vote count, the label of the nearest
/// neighbour among them wins, which makes the result deterministic.
///
/// If `k` exceeds `data.len()`, every entity in `data` votes.
///
/// # Panics
/// Panics if `k == 0` or `data` is empty, since then nobody can vote.
fn classify_knn(data: &[Iris], k: usize, new_entity: &Iris) -> String {
    assert!(k > 0, "classify_knn: k must be at least 1");
    assert!(
        !data.is_empty(),
        "classify_knn: cannot classify against an empty dataset"
    );

    // Pair each training entity with its distance to the query point.
    let mut distances: Vec<(f64, &Iris)> = data
        .iter()
        .map(|entity| (euclidean_distance(entity, new_entity), entity))
        .collect();

    // `total_cmp` is a total order on f64, so a NaN distance (e.g. from a NaN
    // feature) can no longer make the comparator panic.
    distances.sort_by(|a, b| a.0.total_cmp(&b.0));

    // `take` clamps to the available entities instead of indexing past the end.
    let neighbours: Vec<&str> = distances
        .iter()
        .take(k)
        .map(|(_, entity)| entity.species.as_str())
        .collect();

    let mut label_counts: HashMap<&str, usize> = HashMap::new();
    for &label in &neighbours {
        *label_counts.entry(label).or_insert(0) += 1;
    }

    let best_count = label_counts.values().copied().max().unwrap_or(0);

    // HashMap iteration order is unspecified, so ties are resolved by walking
    // the neighbours nearest-first and taking the first label with the top count.
    neighbours
        .iter()
        .find(|&&label| label_counts[label] == best_count)
        .map(|&label| label.to_owned())
        .expect("at least one neighbour voted")
}
/// Loads the Iris CSV, classifies one hard-coded sample with k = 3 and prints
/// the predicted species.
///
/// # Errors
/// Propagates any error from opening or deserializing the CSV file.
fn main() -> Result<(), Box<dyn Error>> {
    const K: usize = 3;

    let training_set = read_data_from_csv("../datastore/iris_data.csv")?;

    // The species is what we want to predict, so it starts out empty.
    let query = Iris {
        sepal_length: 5.1,
        sepal_width: 3.5,
        petal_length: 1.4,
        petal_width: 0.2,
        species: String::new(),
    };

    let label = classify_knn(&training_set, K, &query);
    println!(
        "Classified label for Iris flower (Sepal Length: {}, Sepal Width: {}, Petal Length: {}, Petal Width: {}) is: {}",
        query.sepal_length, query.sepal_width, query.petal_length, query.petal_width, label
    );
    Ok(())
}