//! Trains a small MLP on MNIST with `tch`, then classifies the image files
//! found in `./imgs`, checking each prediction against the digit encoded at
//! the start of the file name.

use anyhow::{Context, Result};
use image::io::Reader as ImageReader;
use std::path::Path;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Tensor};

/// Number of input features: one per pixel of a 28x28 image.
const IMAGE_DIM: i64 = 784;
/// Width of the single hidden layer.
const HIDDEN_NODES: i64 = 128;
/// Number of output classes.
const LABELS: i64 = 10;
/// Side length, in pixels, that input images are resized to; `IMAGE_SIDE^2 == IMAGE_DIM`.
const IMAGE_SIDE: u32 = 28;

/// Builds the network `IMAGE_DIM -> HIDDEN_NODES -> ReLU -> LABELS`.
///
/// Parameters are registered under `vs / "layer1"` and `vs / "layer2"`.
fn net(vs: &nn::Path) -> impl Module {
    nn::seq()
        .add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
        .add_fn(|xs| xs.relu())
        .add(nn::linear(vs / "layer2", HIDDEN_NODES, LABELS, Default::default()))
}

/// Loads MNIST from the `data` directory and trains the network on it.
///
/// Each of the 99 epochs (`1..100`) is one full-batch Adam step (learning
/// rate 1e-3) over the whole training set, followed by an evaluation on the
/// test set. Loss and test accuracy are printed after every epoch.
///
/// # Errors
/// Fails if the dataset cannot be loaded or the optimizer cannot be built.
pub fn run() -> Result<impl Module> {
    let m = tch::vision::mnist::load_dir("data")?;
    println!("train-images: {:?}", m.train_images.size());
    println!("train-labels: {:?}", m.train_labels.size());
    println!("test-images: {:?}", m.test_images.size());
    println!("test-labels: {:?}", m.test_labels.size());

    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;

    for epoch in 1..100 {
        let loss = net
            .forward(&m.train_images)
            .cross_entropy_for_logits(&m.train_labels);
        opt.backward_step(&loss);

        // Evaluation never back-propagates, so don't build an autograd graph.
        let test_accuracy = tch::no_grad(|| {
            net.forward(&m.test_images)
                .accuracy_for_logits(&m.test_labels)
        });
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            100. * f64::from(&test_accuracy),
        );
    }
    Ok(net)
}

/// Loads an image file and flattens it into an `IMAGE_DIM`-element f32 tensor.
///
/// The image is resized to exactly `IMAGE_SIDE` x `IMAGE_SIDE`. The aspect
/// ratio is deliberately not preserved, so the element count always matches
/// the network input. It is then converted to grayscale with `to_luma32f`.
///
/// # Errors
/// Fails if the file cannot be opened or decoded as an image.
fn png_to_tensor(file: &Path) -> Result<Tensor> {
    let img = ImageReader::open(file)
        .with_context(|| format!("cannot open {}", file.display()))?
        .decode()
        .with_context(|| format!("cannot decode {}", file.display()))?;
    // `resize` keeps the aspect ratio and yields fewer than IMAGE_DIM pixels
    // for non-square inputs, breaking the first linear layer; `resize_exact`
    // always yields IMAGE_SIDE x IMAGE_SIDE.
    let img = img.resize_exact(IMAGE_SIDE, IMAGE_SIDE, image::imageops::FilterType::Lanczos3);
    let pixels = img.to_luma32f().into_vec();
    Ok(Tensor::of_slice(&pixels))
}

/// Classifies one image file and compares the prediction with the label
/// encoded in its file name.
///
/// The expected label is the part of the file stem before the first `_`
/// (e.g. `7_mine.png` -> 7). Prints one result line and returns whether the
/// prediction matched.
///
/// # Errors
/// Fails if the image cannot be loaded, the file stem is not valid UTF-8, or
/// the stem does not start with an unsigned integer label.
fn id_image(net: &impl Module, file: &Path) -> Result<bool> {
    let input = png_to_tensor(file)?;
    let logits = tch::no_grad(|| net.forward(&input));
    let n_logits = logits.size()[0];

    // Index of the largest logit (the first one on ties). `total_cmp` gives
    // a total order, so a NaN logit cannot cause a panic here.
    let predicted = (0..n_logits)
        .map(|i| (i as usize, logits.double_value(&[i])))
        .min_by(|(_, a), (_, b)| b.total_cmp(a))
        .map(|(i, _)| i)
        .context("network produced no logits")?;

    let name = file
        .file_stem()
        .and_then(|stem| stem.to_str())
        .with_context(|| format!("{} has no UTF-8 file stem", file.display()))?;
    // `split` always yields at least one item, so the fallback is never used.
    let label_text = name.split('_').next().unwrap_or(name);
    let expected = label_text.parse::<usize>().with_context(|| {
        format!("file name {:?} does not start with a digit label", name)
    })?;

    let is_ok = expected == predicted;
    println!(
        "== {:4} => {} // {}",
        name,
        predicted,
        if is_ok { "ok :)" } else { "KO :(" }
    );
    Ok(is_ok)
}

/// Trains the network, classifies every file in `./imgs` in path order, and
/// prints the overall hit rate.
///
/// # Errors
/// Fails if training fails, `./imgs` cannot be listed, or any image cannot be
/// classified.
fn main() -> Result<()> {
    let net = run()?;

    let mut paths = std::fs::read_dir("./imgs")
        .context("cannot read ./imgs")?
        .map(|entry| entry.map(|e| e.path()))
        .collect::<std::io::Result<Vec<_>>>()
        .context("cannot list ./imgs")?;
    // Sub-directories can never be decoded as images.
    paths.retain(|p| p.is_file());
    paths.sort();

    if paths.is_empty() {
        // Avoid printing a meaningless 0/0 = NaN percentage.
        println!("=> no images found in ./imgs");
        return Ok(());
    }

    let mut hits: u32 = 0;
    for path in &paths {
        if id_image(&net, path)? {
            hits += 1;
        }
    }
    println!("=> {:.2}%", (hits as f32 / paths.len() as f32) * 100.0);
    Ok(())
}