// rust-pytorch/src/main.rs
//
// 94 lines
// 2.7 KiB
// Rust

use image::io::Reader as ImageReader;
use std::path::Path;
use anyhow::Result;
use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Tensor};
/// Number of input features: one per pixel of a 28x28 grayscale image.
const IMAGE_DIM: i64 = 28 * 28;
/// Width of the single hidden layer.
const HIDDEN_NODES: i64 = 128;
/// Number of output classes (MNIST digits 0-9).
const LABELS: i64 = 10;
/// Builds a two-layer perceptron: `IMAGE_DIM -> HIDDEN_NODES -> ReLU -> LABELS`.
///
/// The hidden layer's variables are registered under `vs / "layer1"`; the output
/// layer's variables are registered directly under `vs`. The returned module
/// produces raw logits (no softmax).
fn net(vs: &nn::Path) -> impl Module {
    // Layers are created in the same order as they are applied, so parameter
    // initialisation order is unchanged.
    let hidden = nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default());
    let output = nn::linear(vs, HIDDEN_NODES, LABELS, Default::default());
    nn::seq().add(hidden).add_fn(|xs| xs.relu()).add(output)
}
/// Loads MNIST from the `data` directory, trains the network built by [`net`]
/// with Adam (learning rate `1e-3`) on the CPU, and returns the trained module.
///
/// Training is full-batch: every epoch runs one forward pass over the whole
/// training set followed by a single optimizer step. After each step the test
/// accuracy is printed together with the training loss.
///
/// # Errors
/// Returns an error if the MNIST files cannot be loaded from `data` or the
/// optimizer cannot be built.
pub fn run() -> Result<impl Module> {
    let m = tch::vision::mnist::load_dir("data")?;
    println!("train-images: {:?}", m.train_images.size());
    println!("train-labels: {:?}", m.train_labels.size());
    println!("test-images: {:?}", m.test_images.size());
    println!("test-labels: {:?}", m.test_labels.size());
    let vs = nn::VarStore::new(Device::Cpu);
    let net = net(&vs.root());
    let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
    // Half-open range: this runs epochs 1 through 99.
    for epoch in 1..100 {
        let loss = net
            .forward(&m.train_images)
            .cross_entropy_for_logits(&m.train_labels);
        opt.backward_step(&loss);
        // Evaluation needs no gradients; running it under `no_grad` avoids
        // building (and then discarding) an autograd graph over the test set.
        let test_accuracy = tch::no_grad(|| {
            net.forward(&m.test_images)
                .accuracy_for_logits(&m.test_labels)
        });
        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            f64::from(&loss),
            100. * f64::from(&test_accuracy),
        );
    }
    Ok(net)
}
/// Loads an image file and converts it into a flat `f32` tensor for [`net`].
///
/// The image is decoded (format guessed by the `image` crate), resized to
/// exactly 28x28 pixels with a Lanczos3 filter, and converted to single-channel
/// `f32` luminance. The result is a 1-D tensor of `28 * 28` values in the pixel
/// order produced by `into_vec`.
///
/// # Errors
/// Returns an error if the file cannot be opened or decoded.
fn png_to_tensor(file: &Path) -> Result<Tensor> {
    let img = ImageReader::open(file)?.decode()?;
    // `resize` preserves the aspect ratio and would produce fewer than 784
    // pixels for non-square inputs (breaking the first linear layer);
    // `resize_exact` always yields 28x28.
    let img = img.resize_exact(28, 28, image::imageops::FilterType::Lanczos3);
    let luma = img.to_luma32f();
    let v = luma.into_vec();
    Ok(Tensor::of_slice(&v))
}
fn id_image(net: &impl Module, file: &Path) -> Result<bool> {
let t = png_to_tensor(file)?;
let res = net.forward(&t);
let sizes = res.size();
let mut res2 : Vec<(usize, f64)> = (0..sizes[0] as usize)
.map(| i | (i, res.double_value(&[i as i64])))
.collect();
res2.sort_by(| (_, a), (_, b) | b.partial_cmp(a).unwrap());
let name = file.file_stem().unwrap().to_str().unwrap();
let tbl : Vec<&str> = name.split('_').collect();
let nbr = tbl[0].parse::<usize>().unwrap();
let is_ok = nbr == res2[0].0;
println!("== {:4} => {} // {}", name, res2[0].0,
if is_ok { "ok :)" } else { "KO :(" });
Ok(is_ok)
}
/// Trains the network, then classifies every entry of `./imgs` in sorted path
/// order and prints the overall accuracy.
///
/// # Errors
/// Propagates errors from training, from listing `./imgs`, and from
/// classifying any individual entry (a non-image entry aborts the run).
fn main() -> Result<()> {
    let net = run()?;
    let mut paths = std::fs::read_dir("./imgs")?
        .map(|entry| entry.map(|e| e.path()))
        .collect::<std::io::Result<Vec<_>>>()?;
    paths.sort();
    let mut is_oks: u32 = 0;
    let mut totals: u32 = 0;
    for path in &paths {
        if id_image(&net, path)? {
            is_oks += 1;
        }
        totals += 1;
    }
    // Avoid 0/0 (which would print "NaN%") when the directory is empty.
    if totals == 0 {
        println!("=> no images found in ./imgs");
    } else {
        println!("=> {:.2}%", (is_oks as f32 / totals as f32) * 100.0);
    }
    Ok(())
}