94 lines
2.7 KiB
Rust
94 lines
2.7 KiB
Rust
use image::io::Reader as ImageReader;
|
|
use std::path::Path;
|
|
|
|
use anyhow::Result;
|
|
use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Tensor};
|
|
|
|
/// Number of input features: one per pixel of a 28×28 grayscale image.
const IMAGE_DIM: i64 = 28 * 28;
/// Width of the single hidden layer.
const HIDDEN_NODES: i64 = 128;
/// Number of output classes, one logit per digit label.
const LABELS: i64 = 10;
|
|
|
|
fn net(vs: &nn::Path) -> impl Module {
|
|
nn::seq()
|
|
.add(nn::linear(vs / "layer1", IMAGE_DIM,
|
|
HIDDEN_NODES, Default::default()))
|
|
.add_fn(|xs| xs.relu())
|
|
.add(nn::linear(vs, HIDDEN_NODES,
|
|
LABELS, Default::default()))
|
|
}
|
|
|
|
pub fn run() -> Result<impl Module> {
|
|
let m = tch::vision::mnist::load_dir("data")?;
|
|
println!("train-images: {:?}", m.train_images.size());
|
|
println!("train-labels: {:?}", m.train_labels.size());
|
|
println!("test-images: {:?}", m.test_images.size());
|
|
println!("test-labels: {:?}", m.test_labels.size());
|
|
|
|
let vs = nn::VarStore::new(Device::Cpu);
|
|
let net = net(&vs.root());
|
|
let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
|
|
for epoch in 1..100 {
|
|
let loss = net.forward(&m.train_images)
|
|
.cross_entropy_for_logits(&m.train_labels);
|
|
opt.backward_step(&loss);
|
|
let test_accuracy = net.forward(&m.test_images)
|
|
.accuracy_for_logits(&m.test_labels);
|
|
println!(
|
|
"epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
|
|
epoch,
|
|
f64::from(&loss),
|
|
100. * f64::from(&test_accuracy),
|
|
);
|
|
}
|
|
Ok(net)
|
|
}
|
|
|
|
fn png_to_tensor(file: &Path) -> Result<Tensor> {
|
|
let img = ImageReader::open(file)?.decode()?;
|
|
let img = img.resize(28, 28, image::imageops::FilterType::Lanczos3);
|
|
let luma = img.to_luma32f();
|
|
let v = luma.into_vec();
|
|
Ok(Tensor::of_slice(&v))
|
|
}
|
|
|
|
fn id_image(net: &impl Module, file: &Path) -> Result<bool> {
|
|
let t = png_to_tensor(file)?;
|
|
let res = net.forward(&t);
|
|
|
|
let sizes = res.size();
|
|
let mut res2 : Vec<(usize, f64)> = (0..sizes[0] as usize)
|
|
.map(| i | (i, res.double_value(&[i as i64])))
|
|
.collect();
|
|
res2.sort_by(| (_, a), (_, b) | b.partial_cmp(a).unwrap());
|
|
|
|
let name = file.file_stem().unwrap().to_str().unwrap();
|
|
let tbl : Vec<&str> = name.split('_').collect();
|
|
let nbr = tbl[0].parse::<usize>().unwrap();
|
|
let is_ok = nbr == res2[0].0;
|
|
println!("== {:4} => {} // {}", name, res2[0].0,
|
|
if is_ok { "ok :)" } else { "KO :(" });
|
|
Ok(is_ok)
|
|
}
|
|
|
|
fn main() -> Result<()> {
|
|
let net = run()?;
|
|
|
|
let mut paths: Vec<_> = std::fs::read_dir("./imgs").unwrap()
|
|
.map(|r| r.unwrap())
|
|
.collect();
|
|
paths.sort_by_key(|dir| dir.path());
|
|
|
|
let mut is_oks : u32 = 0;
|
|
let mut totals : u32 = 0;
|
|
|
|
for path in paths {
|
|
let is_ok = id_image(&net, &path.path())?;
|
|
is_oks += if is_ok { 1 } else { 0 };
|
|
totals += 1;
|
|
}
|
|
|
|
println!("=> {:.2}%", (is_oks as f32 / totals as f32) * 100.0);
|
|
|
|
Ok(())
|
|
}
|