weblog

Sunday 19 May 2024

title: KNN notes ----

Euclidean distance (n-dimensional): \(d(\mathbf{p}, \mathbf{q}) = \sqrt{\sum_{i=1}^{n} (q_i - p_i)^2}\)

use serde::Deserialize;
use std::collections::HashMap; // For counting occurrences of labels
use std::error::Error; 

/// A single Iris flower record: four numeric traits (n = 4) plus a species label.
///
/// Records are produced by `read_data_from_csv` through serde deserialization, and
/// `euclidean_distance` compares two records using only the four numeric fields.
// NOTE(review): presumably the CSV header names must match these field names for
// deserialization to succeed — confirm against the data file.
#[derive(Debug, Deserialize)]
struct Iris {
    /// Sepal length measurement.
    sepal_length: f64,
    /// Sepal width measurement.
    sepal_width: f64,
    /// Petal length measurement.
    petal_length: f64,
    /// Petal width measurement.
    petal_width: f64,
    /// Species label, i.e. the class that `classify_knn` votes on.
    /// `main` leaves it empty for a flower whose species is not yet known.
    species: String,
}

// Read dataset from CSV file
fn read_data_from_csv(file_path: &str) -> Result<Vec<Iris>, Box<dyn Error>> {
    let mut rdr = csv::Reader::from_path(file_path)?;
    let mut data = Vec::new();

  for result in rdr.deserialize() {
    let record: Iris = result?;
    data.push(record);
    }

Ok(data)
}

/// Returns the Euclidean (straight-line) distance between two flowers in the
/// four-dimensional space of their sepal and petal measurements.
///
/// The `species` field plays no part in the calculation.
fn euclidean_distance(e1: &Iris, e2: &Iris) -> f64 {
    // Squared difference along a single axis.
    let squared_gap = |a: f64, b: f64| (a - b).powi(2);

    let sepal_length_term = squared_gap(e1.sepal_length, e2.sepal_length);
    let sepal_width_term = squared_gap(e1.sepal_width, e2.sepal_width);
    let petal_length_term = squared_gap(e1.petal_length, e2.petal_length);
    let petal_width_term = squared_gap(e1.petal_width, e2.petal_width);

    // Terms are added left to right, in field order.
    let sum_of_squares =
        sepal_length_term + sepal_width_term + petal_length_term + petal_width_term;

    sum_of_squares.sqrt()
}

/// Classifies `new_entity` with the k-nearest-neighbours algorithm.
///
/// Every flower in `data` is ranked by its Euclidean distance to `new_entity`; the
/// `k` closest ones vote with their `species` label and the most frequent label wins.
///
/// * `data` - labelled flowers to search for neighbours.
/// * `k` - number of neighbours that vote. If `k` exceeds the number of usable
///   flowers, all of them vote.
/// * `new_entity` - the flower to classify; its own `species` is ignored.
///
/// Ties in the vote are resolved deterministically: among the labels sharing the
/// highest count, the one owning the nearest neighbour wins.
///
/// Returns the winning label, or an empty string when there is nothing to vote with
/// (`data` is empty, `k` is 0, or every distance is NaN).
fn classify_knn(data: &[Iris], k: usize, new_entity: &Iris) -> String {
    // Pair each flower with its distance to `new_entity`. A NaN distance cannot be
    // ranked, so such flowers are left out instead of aborting the sort.
    let mut distances: Vec<(f64, &Iris)> = data
        .iter()
        .map(|entity| (euclidean_distance(entity, new_entity), entity))
        .filter(|(distance, _)| !distance.is_nan())
        .collect();

    // Stable ascending sort: equally distant flowers keep their order from `data`.
    distances.sort_by(|a, b| a.0.total_cmp(&b.0));

    // Clamp `k` so that asking for more neighbours than exist does not panic.
    let neighbours = &distances[..k.min(distances.len())];

    // Count how many of the nearest neighbours carry each label.
    let mut label_counts: HashMap<&str, usize> = HashMap::new();
    for (_, entity) in neighbours {
        *label_counts.entry(entity.species.as_str()).or_insert(0) += 1;
    }

    // Walk the neighbours nearest-first and only replace the current winner on a
    // strictly higher count. This keeps the result independent of the hash map's
    // iteration order, which is not stable between runs.
    let mut winner: Option<(&str, usize)> = None;
    for (_, entity) in neighbours {
        let label = entity.species.as_str();
        let count = label_counts[label];
        if winner.map_or(true, |(_, best)| count > best) {
            winner = Some((label, count));
        }
    }

    winner
        .map(|(label, _)| label.to_owned())
        .unwrap_or_default()
}

/// Entry point: loads the Iris dataset, classifies one hard-coded sample with
/// KNN (k = 3) and prints the predicted species.
///
/// # Errors
///
/// Propagates any error returned while reading the CSV file.
fn main() -> Result<(), Box<dyn Error>> {
    // Load the labelled flowers that act as the neighbour pool.
    let training_set = read_data_from_csv("../datastore/iris_data.csv")?;

    // The sample to classify. Its species is unknown, so the label is left empty.
    let sample = Iris {
        sepal_length: 5.1,
        sepal_width: 3.5,
        petal_length: 1.4,
        petal_width: 0.2,
        species: String::new(),
    };

    // Let the three nearest neighbours vote on the label.
    let predicted_label = classify_knn(&training_set, 3, &sample);

    println!(
        "Classified label for Iris flower (Sepal Length: {}, Sepal Width: {}, Petal Length: {}, Petal Width: {}) is: {}",
        sample.sepal_length,
        sample.sepal_width,
        sample.petal_length,
        sample.petal_width,
        predicted_label
    );

    Ok(())
}