博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
3.2 使用pytorch搭建AlexNet
阅读量:37453 次
发布时间:2020-12-04

本文共 2120 字,大约阅读时间需要 7 分钟。

第一个           关于imagefolder              

上一条文章,学习掌握了os库与json文件,这样我们较为方便地在自己标注数据后,将其划分到两个文件或多个文件夹里。imagefolder就支持这样的数据集类型(个人感觉这样的数据类型是很多不错的)

数据加载器          真正开始使用           还得datafolder

第二个      关于dataset.class_to_idx

将所有花类分出来编码,感觉是生成一个字典吧

感觉意义不大

这一部视频里面写到了json里面,为啥呀?闲得慌

第三部               如何显示图像      

 

第四步                 学习训练步骤

net = AlexNet(num_classes=5, init_weights=True)    net.to(device)    loss_function = nn.CrossEntropyLoss()    # pata = list(net.parameters())    optimizer = optim.Adam(net.parameters(), lr=0.0002)    epochs = 10    save_path = './AlexNet.pth'    best_acc = 0.0    train_steps = len(train_loader)    for epoch in range(epochs):        # train        net.train()        running_loss = 0.0        train_bar = tqdm(train_loader, file=sys.stdout)     #生成一个进度条        for step, data in enumerate(train_bar):            images, labels = data                 #这一步可以略微窥探dataloader中数据的组成格式            optimizer.zero_grad()            outputs = net(images.to(device))            loss = loss_function(outputs, labels.to(device))            loss.backward()            optimizer.step()            # print statistics            running_loss += loss.item()            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,                                                                     epochs,                                                                     loss)        # validate        net.eval()        acc = 0.0  # accumulate accurate number / epoch        with torch.no_grad():            val_bar = tqdm(validate_loader, file=sys.stdout)            for val_data in val_bar:                val_images, val_labels = val_data                outputs = net(val_images.to(device))                predict_y = torch.max(outputs, dim=1)[1]                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()        val_accurate = acc / val_num        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %              (epoch + 1, running_loss / train_steps, val_accurate))        if val_accurate > best_acc: #那还能是零???奇怪极了            best_acc = val_accurate            torch.save(net.state_dict(), save_path)        #在理想路径中保存路径值    print('Finished Training')

  

转载地址:http://rfpowy.baihongyu.com/

你可能感兴趣的文章
C#中的命名空间
查看>>
设计模式——状态模式
查看>>
设计模式——工厂模式
查看>>
Unity中实现有限状态机FSM
查看>>
Unity中实现反弹
查看>>
U3D游戏开发框架(九)——事件序列
查看>>
Unity中解决“SetDestination“ can only be called on an active agent that has been placed on a NavMesh
查看>>
Unity中的刚体
查看>>
Unity中的坐标转换
查看>>
Unity中为什么不能对transform.position.x直接赋值?
查看>>
Unity中物体移动方法详解
查看>>
使用对象池优化性能
查看>>
Unity中的UI方案(基础版)
查看>>
Lua(一)——Lua介绍
查看>>
Lua(二)——环境安装
查看>>
Unity中父子物体的坑
查看>>
基础知识——进位制
查看>>
Lua(十二)——表
查看>>
Lua(十三)——模块与包
查看>>
Lua(四)——变量
查看>>