initial commit
22
Cargo.toml
Normal file
@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "pytorch"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
|
||||
[[bin]]
|
||||
name = "nn"
|
||||
path = "src/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "conv"
|
||||
path = "src/main2.rs"
|
||||
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.69"
|
||||
colored = "2.0.0"
|
||||
image = "0.24.5"
|
||||
tch = "0.10.2"
|
BIN
data/t10k-images-idx3-ubyte
Normal file
BIN
data/t10k-labels-idx1-ubyte
Normal file
BIN
data/train-images-idx3-ubyte
Normal file
BIN
data/train-labels-idx1-ubyte
Normal file
BIN
imgs/00.png
Normal file
After Width: | Height: | Size: 1.7 KiB |
BIN
imgs/01.png
Normal file
After Width: | Height: | Size: 7.5 KiB |
BIN
imgs/02.png
Normal file
After Width: | Height: | Size: 9.0 KiB |
BIN
imgs/02_1.png
Normal file
After Width: | Height: | Size: 12 KiB |
BIN
imgs/03.png
Normal file
After Width: | Height: | Size: 9.5 KiB |
BIN
imgs/04.png
Normal file
After Width: | Height: | Size: 1.5 KiB |
BIN
imgs/05.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
imgs/06.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
imgs/07.png
Normal file
After Width: | Height: | Size: 1.4 KiB |
BIN
imgs/08.png
Normal file
After Width: | Height: | Size: 1.6 KiB |
BIN
imgs/09.png
Normal file
After Width: | Height: | Size: 1.7 KiB |
BIN
imgs/09_1.png
Normal file
After Width: | Height: | Size: 13 KiB |
93
src/main.rs
Normal file
@ -0,0 +1,93 @@
|
||||
use image::io::Reader as ImageReader;
|
||||
use std::path::Path;
|
||||
|
||||
use anyhow::Result;
|
||||
use tch::{nn, nn::Module, nn::OptimizerConfig, Device, Tensor};
|
||||
|
||||
const IMAGE_DIM: i64 = 784;
|
||||
const HIDDEN_NODES: i64 = 128;
|
||||
const LABELS: i64 = 10;
|
||||
|
||||
fn net(vs: &nn::Path) -> impl Module {
|
||||
nn::seq()
|
||||
.add(nn::linear(vs / "layer1", IMAGE_DIM,
|
||||
HIDDEN_NODES, Default::default()))
|
||||
.add_fn(|xs| xs.relu())
|
||||
.add(nn::linear(vs, HIDDEN_NODES,
|
||||
LABELS, Default::default()))
|
||||
}
|
||||
|
||||
pub fn run() -> Result<impl Module> {
|
||||
let m = tch::vision::mnist::load_dir("data")?;
|
||||
println!("train-images: {:?}", m.train_images.size());
|
||||
println!("train-labels: {:?}", m.train_labels.size());
|
||||
println!("test-images: {:?}", m.test_images.size());
|
||||
println!("test-labels: {:?}", m.test_labels.size());
|
||||
|
||||
let vs = nn::VarStore::new(Device::Cpu);
|
||||
let net = net(&vs.root());
|
||||
let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
|
||||
for epoch in 1..100 {
|
||||
let loss = net.forward(&m.train_images)
|
||||
.cross_entropy_for_logits(&m.train_labels);
|
||||
opt.backward_step(&loss);
|
||||
let test_accuracy = net.forward(&m.test_images)
|
||||
.accuracy_for_logits(&m.test_labels);
|
||||
println!(
|
||||
"epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
epoch,
|
||||
f64::from(&loss),
|
||||
100. * f64::from(&test_accuracy),
|
||||
);
|
||||
}
|
||||
Ok(net)
|
||||
}
|
||||
|
||||
fn png_to_tensor(file: &Path) -> Result<Tensor> {
|
||||
let img = ImageReader::open(file)?.decode()?;
|
||||
let img = img.resize(28, 28, image::imageops::FilterType::Lanczos3);
|
||||
let luma = img.to_luma32f();
|
||||
let v = luma.into_vec();
|
||||
Ok(Tensor::of_slice(&v))
|
||||
}
|
||||
|
||||
fn id_image(net: &impl Module, file: &Path) -> Result<bool> {
|
||||
let t = png_to_tensor(file)?;
|
||||
let res = net.forward(&t);
|
||||
|
||||
let sizes = res.size();
|
||||
let mut res2 : Vec<(usize, f64)> = (0..sizes[0] as usize)
|
||||
.map(| i | (i, res.double_value(&[i as i64])))
|
||||
.collect();
|
||||
res2.sort_by(| (_, a), (_, b) | b.partial_cmp(a).unwrap());
|
||||
|
||||
let name = file.file_stem().unwrap().to_str().unwrap();
|
||||
let tbl : Vec<&str> = name.split('_').collect();
|
||||
let nbr = tbl[0].parse::<usize>().unwrap();
|
||||
let is_ok = nbr == res2[0].0;
|
||||
println!("== {:4} => {} // {}", name, res2[0].0,
|
||||
if is_ok { "ok :)" } else { "KO :(" });
|
||||
Ok(is_ok)
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let net = run()?;
|
||||
|
||||
let mut paths: Vec<_> = std::fs::read_dir("./imgs").unwrap()
|
||||
.map(|r| r.unwrap())
|
||||
.collect();
|
||||
paths.sort_by_key(|dir| dir.path());
|
||||
|
||||
let mut is_oks : u32 = 0;
|
||||
let mut totals : u32 = 0;
|
||||
|
||||
for path in paths {
|
||||
let is_ok = id_image(&net, &path.path())?;
|
||||
is_oks += if is_ok { 1 } else { 0 };
|
||||
totals += 1;
|
||||
}
|
||||
|
||||
println!("=> {:.2}%", (is_oks as f32 / totals as f32) * 100.0);
|
||||
|
||||
Ok(())
|
||||
}
|
124
src/main2.rs
Normal file
@ -0,0 +1,124 @@
|
||||
use image::io::Reader as ImageReader;
|
||||
use std::path::Path;
|
||||
use colored::Colorize;
|
||||
|
||||
use anyhow::Result;
|
||||
use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Net {
|
||||
conv1: nn::Conv2D,
|
||||
conv2: nn::Conv2D,
|
||||
fc1: nn::Linear,
|
||||
fc2: nn::Linear,
|
||||
}
|
||||
|
||||
impl Net {
|
||||
fn new(vs: &nn::Path) -> Net {
|
||||
let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
|
||||
let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
|
||||
let fc1 = nn::linear(vs, 1024, 1024, Default::default());
|
||||
let fc2 = nn::linear(vs, 1024, 10, Default::default());
|
||||
Net { conv1, conv2, fc1, fc2 }
|
||||
}
|
||||
}
|
||||
|
||||
impl nn::ModuleT for Net {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
|
||||
xs.view([-1, 1, 28, 28])
|
||||
.apply(&self.conv1)
|
||||
.max_pool2d_default(2)
|
||||
.apply(&self.conv2)
|
||||
.max_pool2d_default(2)
|
||||
.view([-1, 1024])
|
||||
.apply(&self.fc1)
|
||||
.relu()
|
||||
.dropout(0.5, train)
|
||||
.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
fn run() -> Result<Net> {
|
||||
let m = tch::vision::mnist::load_dir("data")?;
|
||||
let vs = nn::VarStore::new(Device::cuda_if_available());
|
||||
let net = Net::new(&vs.root());
|
||||
let mut opt = nn::Adam::default().build(&vs, 1e-4)?;
|
||||
for epoch in 1..10 {
|
||||
for (bimages, blabels) in m.train_iter(256)
|
||||
.shuffle().to_device(vs.device()) {
|
||||
let loss = net.forward_t(&bimages, true)
|
||||
.cross_entropy_for_logits(&blabels);
|
||||
opt.backward_step(&loss);
|
||||
}
|
||||
let test_accuracy =
|
||||
net.batch_accuracy_for_logits(&m.test_images,
|
||||
&m.test_labels, vs.device(), 1024);
|
||||
println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy);
|
||||
if test_accuracy > 0.995 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(net)
|
||||
}
|
||||
|
||||
fn png_to_tensor(file: &Path) -> Result<Tensor> {
|
||||
let img = ImageReader::open(file)?.decode()?;
|
||||
let img = img.resize(28, 28, image::imageops::FilterType::Lanczos3);
|
||||
let luma = img.to_luma32f();
|
||||
let v = luma.into_vec();
|
||||
Ok(Tensor::of_slice(&v))
|
||||
}
|
||||
|
||||
fn id_image(net: &Net, file: &Path) -> Result<bool> {
|
||||
let t = png_to_tensor(file)?;
|
||||
let res = net.forward_t(&t, false);
|
||||
|
||||
let sizes = res.size();
|
||||
|
||||
let mut res2 : Vec<(usize, f64)> = (0..sizes[1] as usize)
|
||||
.map(| i | (i, res.get(0).double_value(&[i as i64])))
|
||||
.collect();
|
||||
res2.sort_by(| (_, a), (_, b) | b.partial_cmp(a).unwrap());
|
||||
|
||||
let name = file.file_stem().unwrap().to_str().unwrap();
|
||||
let tbl : Vec<&str> = name.split('_').collect();
|
||||
let nbr = tbl[0].parse::<usize>().unwrap();
|
||||
let is_ok = nbr == res2[0].0;
|
||||
let certainty = res2[0].1 / res2[1].1;
|
||||
let s = format!("certainty: {:.2} // {}",
|
||||
certainty,
|
||||
if is_ok { "ok :)" } else { "KO :(" });
|
||||
println!("== {:4} => {} // {}", name, res2[0].0,
|
||||
if is_ok && certainty > 2.0 {
|
||||
s.green()
|
||||
} else if is_ok {
|
||||
s.yellow()
|
||||
} else {
|
||||
s.red()
|
||||
});
|
||||
|
||||
Ok(is_ok)
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let net = run()?;
|
||||
|
||||
let mut paths: Vec<_> = std::fs::read_dir("./imgs").unwrap()
|
||||
.map(|r| r.unwrap())
|
||||
.collect();
|
||||
paths.sort_by_key(|dir| dir.path());
|
||||
|
||||
let mut is_oks : u32 = 0;
|
||||
let mut totals : u32 = 0;
|
||||
|
||||
for path in paths {
|
||||
let is_ok = id_image(&net, &path.path())?;
|
||||
is_oks += if is_ok { 1 } else { 0 };
|
||||
totals += 1;
|
||||
}
|
||||
|
||||
println!("=> {:.2}%", (is_oks as f32 / totals as f32) * 100.0);
|
||||
|
||||
Ok(())
|
||||
}
|