125 lines
3.4 KiB
Rust
125 lines
3.4 KiB
Rust
use image::io::Reader as ImageReader;
|
|
use std::path::Path;
|
|
use colored::Colorize;
|
|
|
|
use anyhow::Result;
|
|
use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
|
|
|
|
#[derive(Debug)]
|
|
struct Net {
|
|
conv1: nn::Conv2D,
|
|
conv2: nn::Conv2D,
|
|
fc1: nn::Linear,
|
|
fc2: nn::Linear,
|
|
}
|
|
|
|
impl Net {
|
|
fn new(vs: &nn::Path) -> Net {
|
|
let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
|
|
let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
|
|
let fc1 = nn::linear(vs, 1024, 1024, Default::default());
|
|
let fc2 = nn::linear(vs, 1024, 10, Default::default());
|
|
Net { conv1, conv2, fc1, fc2 }
|
|
}
|
|
}
|
|
|
|
impl nn::ModuleT for Net {
|
|
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
|
|
xs.view([-1, 1, 28, 28])
|
|
.apply(&self.conv1)
|
|
.max_pool2d_default(2)
|
|
.apply(&self.conv2)
|
|
.max_pool2d_default(2)
|
|
.view([-1, 1024])
|
|
.apply(&self.fc1)
|
|
.relu()
|
|
.dropout(0.5, train)
|
|
.apply(&self.fc2)
|
|
}
|
|
}
|
|
|
|
fn run() -> Result<Net> {
|
|
let m = tch::vision::mnist::load_dir("data")?;
|
|
let vs = nn::VarStore::new(Device::cuda_if_available());
|
|
let net = Net::new(&vs.root());
|
|
let mut opt = nn::Adam::default().build(&vs, 1e-4)?;
|
|
for epoch in 1..10 {
|
|
for (bimages, blabels) in m.train_iter(256)
|
|
.shuffle().to_device(vs.device()) {
|
|
let loss = net.forward_t(&bimages, true)
|
|
.cross_entropy_for_logits(&blabels);
|
|
opt.backward_step(&loss);
|
|
}
|
|
let test_accuracy =
|
|
net.batch_accuracy_for_logits(&m.test_images,
|
|
&m.test_labels, vs.device(), 1024);
|
|
println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy);
|
|
if test_accuracy > 0.995 {
|
|
break;
|
|
}
|
|
}
|
|
|
|
Ok(net)
|
|
}
|
|
|
|
fn png_to_tensor(file: &Path) -> Result<Tensor> {
|
|
let img = ImageReader::open(file)?.decode()?;
|
|
let img = img.resize(28, 28, image::imageops::FilterType::Lanczos3);
|
|
let luma = img.to_luma32f();
|
|
let v = luma.into_vec();
|
|
Ok(Tensor::of_slice(&v))
|
|
}
|
|
|
|
fn id_image(net: &Net, file: &Path) -> Result<bool> {
|
|
let t = png_to_tensor(file)?;
|
|
let res = net.forward_t(&t, false);
|
|
|
|
let sizes = res.size();
|
|
|
|
let mut res2 : Vec<(usize, f64)> = (0..sizes[1] as usize)
|
|
.map(| i | (i, res.get(0).double_value(&[i as i64])))
|
|
.collect();
|
|
res2.sort_by(| (_, a), (_, b) | b.partial_cmp(a).unwrap());
|
|
|
|
let name = file.file_stem().unwrap().to_str().unwrap();
|
|
let tbl : Vec<&str> = name.split('_').collect();
|
|
let nbr = tbl[0].parse::<usize>().unwrap();
|
|
let is_ok = nbr == res2[0].0;
|
|
let certainty = res2[0].1 / res2[1].1;
|
|
let s = format!("certainty: {:.2} // {}",
|
|
certainty,
|
|
if is_ok { "ok :)" } else { "KO :(" });
|
|
println!("== {:4} => {} // {}", name, res2[0].0,
|
|
if is_ok && certainty > 2.0 {
|
|
s.green()
|
|
} else if is_ok {
|
|
s.yellow()
|
|
} else {
|
|
s.red()
|
|
});
|
|
|
|
Ok(is_ok)
|
|
}
|
|
|
|
fn main() -> Result<()> {
|
|
let net = run()?;
|
|
|
|
let mut paths: Vec<_> = std::fs::read_dir("./imgs").unwrap()
|
|
.map(|r| r.unwrap())
|
|
.collect();
|
|
paths.sort_by_key(|dir| dir.path());
|
|
|
|
let mut is_oks : u32 = 0;
|
|
let mut totals : u32 = 0;
|
|
|
|
for path in paths {
|
|
let is_ok = id_image(&net, &path.path())?;
|
|
is_oks += if is_ok { 1 } else { 0 };
|
|
totals += 1;
|
|
}
|
|
|
|
println!("=> {:.2}%", (is_oks as f32 / totals as f32) * 100.0);
|
|
|
|
Ok(())
|
|
}
|