垃圾分类图像深度学习源码(1)

发布于 2022-08-22  875 次阅读


main.py

import torch
from waste_sorting_dataset import waste_sorting_dataset
from torch.utils.data import DataLoader
from torchvision import models
from train import train
from torch import optim, nn
from loguru import logger
from config import batch_size, num_workers, num_epochs, lr_init, \
    lr_stepsize, weight_decay, device, num_hiddens1, num_hiddens2
logger.add('waste_sorting.log')

if __name__ == '__main__':
    # Model: ResNet-34 pretrained on ImageNet, with its final fully-connected
    # layer replaced by a 4-way classifier head (the core of this project).
    net = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
    # Size the new head from the backbone's own feature width (512 for
    # ResNet-34) instead of hard-coding it, so swapping the backbone does not
    # silently produce a shape mismatch. num_hiddens1 / num_hiddens2 from
    # config can be used here if a deeper (Linear-ReLU-Dropout) head is wanted.
    net.fc = nn.Sequential(
        nn.Linear(net.fc.in_features, 4),
    )
    # To resume from a saved checkpoint:
    # net.load_state_dict(torch.load('./model_5.pt'))

    def update_train_iter(train_mod_num):
        """Build a shuffled DataLoader over training shard ``train_mod_num``."""
        logger.info('update train_iter ' + str(train_mod_num))
        train_dataset = waste_sorting_dataset(train_mod_num)
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True,
                                  pin_memory=True, num_workers=num_workers)
        return train_loader

    def init_val_iter():
        """Build a DataLoader over the validation set (dataset mode -2)."""
        val_data = waste_sorting_dataset(-2)
        val_iter = DataLoader(dataset=val_data, batch_size=batch_size,
                              pin_memory=True, num_workers=num_workers)
        logger.info('nums of val_data: ' + str(len(val_data)))
        return val_iter

    def init_test_iter():
        """Build a DataLoader over the test set (dataset mode -1); not used yet."""
        test_data = waste_sorting_dataset(-1)
        test_iter = DataLoader(dataset=test_data, batch_size=batch_size,
                               pin_memory=True, num_workers=num_workers)
        logger.info('nums of test_iter: ' + str(len(test_data)))
        return test_iter

    # Adam optimizer with weight decay (L2 penalty on the gradients).
    optimizer = optim.Adam(net.parameters(), lr=lr_init, weight_decay=weight_decay)
    # Multiply the learning rate by 0.1 every `lr_stepsize` scheduler.step() calls.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=lr_stepsize, gamma=0.1)

    # Train. The loaders are passed as factory functions so each dataset is
    # only loaded when it is actually needed.
    train(net, update_train_iter, init_val_iter, optimizer, scheduler, device, num_epochs)

    # Testing: not implemented yet; init_test_iter() provides the test loader.

train.py

from torch import gather
from loguru import logger
import torch
import time
import random
from draw import semilogy

# Training entry point.
def train(net, update_train_iter, init_val_iter, optimizer, scheduler, device, num_epochs):
    """Train ``net`` for ``num_epochs`` epochs, checkpointing and validating after each one.

    Data loaders are passed in as factory functions rather than loader
    objects, so a dataset is only instantiated when it is used; this keeps
    memory usage down.

    Args:
        net: the model to train; it is moved to ``device`` here.
        update_train_iter: callable taking a training shard index and
            returning an iterable of ``(img, label)`` batches for that shard.
        init_val_iter: zero-argument callable returning the validation
            batch iterable.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: learning-rate scheduler; ``step()`` is called once per epoch.
        device: device to train on.
        num_epochs: number of epochs to run.

    Returns:
        ``(train_ac, val_ac)``: per-epoch training and validation accuracy,
        each the fraction of correctly classified samples.
    """
    # Project-local helpers, imported at call time as in the original.
    from config import root_path
    from transform_data_to_tensor import reload_file

    net = net.to(device)
    logger.critical('training on: ' + str(device))
    # Classification problem -> cross-entropy loss.
    loss_fn = torch.nn.CrossEntropyLoss()
    # Number of training shard indices visited per epoch for the first data
    # group; reload_file() supplies the count for each later group.
    save_time = 18
    train_ac = []
    val_ac = []
    for epoch in range(num_epochs):
        # Every 20 epochs (except the first) rebuild the data for the next training group.
        if epoch != 0 and epoch % 20 == 0:
            epoch_num = epoch // 20
            logger.info('making datas for training group ' + str(epoch_num))
            save_time = reload_file(root_path, epoch_num)

        net.train()
        train_loss_sum, train_accuracy_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()
        # Visit every shard once per epoch, in a random order.
        for train_mod_num in random.sample(range(save_time), save_time):
            train_iter = update_train_iter(train_mod_num)
            for img, label in train_iter:
                img = img.to(device)
                label = label.to(device)
                label_hat = net(img)
                batch_loss = loss_fn(label_hat, label)
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()
                train_loss_sum += batch_loss.item()
                # Count of correctly classified samples in this batch.
                train_accuracy_sum += (label_hat.argmax(dim=1) == label).sum().item()
                n += label.shape[0]
                batch_count += 1

        # Accuracy is correct predictions over samples seen (not over batches).
        train_acc = train_accuracy_sum / n
        train_ac.append(train_acc)

        # Checkpoint after every epoch (state_dict does not depend on train/eval mode).
        torch.save(net.state_dict(), './pt/model_' + str(epoch) + '.pt')
        scheduler.step()

        # Validate after every epoch.
        val_iter = init_val_iter()
        test_accuracy = evaluate_accuracy(val_iter, net)
        val_ac.append(test_accuracy)
        logger.critical('epoch %d, loss %.4f, train acc %.3f, val acc %.3f, time %.1f sec'
            % (epoch + 1, train_loss_sum / batch_count, train_acc, test_accuracy, time.time() - start))
    return train_ac, val_ac
        

# Classification accuracy.
def evaluate_accuracy(data_iter, net, device=None):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    Args:
        data_iter: iterable of ``(X, y)`` batches, where ``y`` holds class indices.
        net: a ``torch.nn.Module``, or a plain function returning per-class
            scores. A function whose code declares an ``is_training``
            variable is called with ``is_training=False``.
        device: device batches are moved to when ``net`` is a module.
            Defaults to the device of the module's first parameter, or CPU
            if it has no parameters.

    Returns:
        Fraction of samples whose argmax prediction equals the label.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    is_module = isinstance(net, torch.nn.Module)
    if device is None and is_module:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    acc_sum, n = 0.0, 0
    # Switch to eval mode once (disables dropout, uses running batch-norm
    # statistics) and restore the caller's mode afterwards, instead of
    # toggling per batch and always leaving the module in train mode.
    was_training = net.training if is_module else None
    if is_module:
        net.eval()
    try:
        with torch.no_grad():
            for X, y in data_iter:
                if is_module:
                    acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().item()
                elif 'is_training' in net.__code__.co_varnames:
                    # Custom function model with an is_training switch.
                    acc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item()
                else:
                    acc_sum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if is_module:
            net.train(was_training)
    return acc_sum / n
        

waste_sorting_dataset.py

from torch.utils.data import Dataset, DataLoader
import random
import torch

# Mode selector: -2 = validation, -1 = test, 0 and above = training shard index.
class waste_sorting_dataset(Dataset):
    """Dataset backed by one pre-built ``.pt`` file of ``(img, label)`` pairs.

    ``train_mod_num`` picks the file: ``-2`` loads ``./pt/val.pt``, ``-1``
    loads ``./pt/test.pt``, and any other value loads the training shard
    ``./pt/train_<train_mod_num>.pt``. The whole file is loaded eagerly.
    """

    def __init__(self, train_mod_num):
        if train_mod_num == -2:
            source = './pt/val.pt'
        elif train_mod_num == -1:
            source = './pt/test.pt'
        else:
            source = './pt/train_%s.pt' % (train_mod_num,)
        # NOTE(review): depending on the PyTorch version / weights_only
        # setting, torch.load may unpickle arbitrary objects; only load
        # files this project produced itself.
        self.imgs = torch.load(source)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, index):
        sample = self.imgs[index]
        img, label = sample
        return img, label


# if __name__ == "__main__":
#     # 外部打乱 + 内部的shuffle = True打乱获得随机效果
#     random_train_mod = random.sample(range(0, 13), 13)
#     print(random_train_mod)
#     for train_mod_num in random_train_mod:
#         img, label = run_train_loader(train_mod_num)
#         print(train_mod_num)  

海纳百川 有容乃大