博客
关于我
Pytorch学习笔记(二)自用
阅读量:678 次
发布时间:2019-03-17

本文共 2839 字,大约阅读时间需要 9 分钟。

PyTorch 模型实战系列学习笔记

本周学习内容:

  • PyTorch 实现CNN分类器,识别MNIST数据集
  • 以CNN为例,实现GPU加速
  • PyTorch 实现RNN分类器,识别MNIST数据集
  • PyTorch 实现RNN回归,用sin去拟合cos
  • PyTorch 实现自编码器
  • PyTorch 实现DQN,模拟小车顶棍子
  • PyTorch 实现GAN,画曲线

环境配置:

  • Python=3.7
  • torch=1.6.0
  • torchvision=0.7.0

6. CNN

CNN 模型实现

def __init__(self):    super(CNN, self).__init__()    self.conv1 = nn.Sequential(        nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),        nn.ReLU(),        nn.MaxPool2d(kernel_size=2)    )    self.conv2 = nn.Sequential(        nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),        nn.ReLU(),        nn.MaxPool2d(kernel_size=2)    )    self.out = nn.Linear(32*7*7, 10)

模型功能

  • 输入处理:批量大小默认为50,输入形状为(B, 1, 28, 28)
  • 过�会:两层卷积层,每层包含卷积、激活函数和池化操作
  • 输出层:全连接层输出10个分类
  • 前向传递:输入经过conv1和conv2两层后,形状变为(B, 32, 7, 7),再转换为(B, 3277),通过出层得到预测结果
  • 损失函数:交叉熵损失

训练过程

  • 数据集:使用MNIST数据集,训练集60000样本,测试集5000样本
  • 优化器:Adam优化器,学习率为0.001
  • 可视化:每50个步骤输出训练损失和测试准确率
  • 预测结果:测试集预测结果与真实值对比,输出为 [7,2,1,0,4,1,4,9,5,9]

7. RNN_classifier

RNN 类ification 模型

def forward(self, x):    r_out, (h_n, h_c) = self.rnn(x, None)    return self.out(r_out[:, -1, :])

模型架构

  • 输入层:时间步为28,输入维度为28×1
  • RNN 层:LSTM层,初始隐藏状态为零
  • 输出层:全连接层输出10个分类
  • 损失函数:交叉熵损失

训练过程

  • 数据集:MNIST数据集,批量大小为64
  • 优化器:Adam优化器,学习率为0.01
  • 可视化:每50个步骤输出训练损失和测试准确率
  • 预测结果:测试集预测结果与真实值对比,输出为 [7,2,1,0,4,1,4,9,8,9]

8. RNN_regressor

RNN 回归 模型

def forward(self, x, h_state):    r_out, h_state = self.rnn(x, h_state)    return self.out(r_out[:, -1, :]), h_state

模型特点

  • 小范围应用:用于拟合sin函数输出的cos函数
  • 输入处理:输入为时间序列,形状为(B, 1)
  • RNN 层:LSTM层,隐藏状态大小为32
  • 输出层:线性层输出1个预测值
  • 训练目标:最小化预测值与真实值的MSE损失

训练过程

  • 训练数据:生成的sin数据和对应的cos数据
  • 批量大小:64
  • 学习率:0.01
  • 动态图像:每50步更新一次,动态显示预测值与真实值的波形

9. Autoencoder

自编码器 模型

编码器部分

self.encoder = nn.Sequential(    nn.Linear(28*28, 128),    nn.Tanh(),    nn.Linear(128, 64),    nn.Tanh(),    nn.Linear(64, 12),    nn.Tanh(),    nn.Linear(12, 3))

解码器部分

self.decoder = nn.Sequential(    nn.Linear(3, 12),    nn.Tanh(),    nn.Linear(12, 64),    nn.Tanh(),    nn.Linear(64, 128),    nn.Tanh(),    nn.Linear(128, 28*28),    nn.Sigmoid())

学习目的

  • 自编码:对输入图像进行压缩再解码
  • 无监督学习:利用MNIST数据集的图片进行训练
  • 可视化:动态显示压缩后的特征映射

训练过程

  • 批量大小:50
  • 学习率:0.001
  • 训练周期:5个 epoch
  • 动态图像:每100步更新一次,展示压缩后的特征图和原图

10. DQN

DQN 模型结构

class Net(nn.Module):    def __init__(self):        self.fc1 = nn.Linear(N_STATES, 50)        self.fc1.weight.data.normal_(0, 0.1)        self.out = nn.Linear(50, N_ACTIONS)        self.out.weight.data.normal_(0, 0.1)        def forward(self, x):        x = self.fc1(x)        x = F.relu(x)        return self.out(x)

训练过程

  • 强化学习框架:基于CartPole环境
  • 记忆库:存储经验回放5000步
  • 学习策略:在线学习,适当替换目标网络参数
  • 动作选择:ε-greedy策略,平衡探索与利用
  • 可视化:动态显示 每次行动结果和奖励

11. GAN

GAN 模型架构

G = nn.Sequential(    nn.Linear(N_IDEAS, 128),    nn.ReLU(),    nn.Linear(128, ART_COMPONENTS))D = nn.Sequential(    nn.Linear(ART_COMPONENTS, 128),    nn.ReLU(),    nn.Linear(128, 1),    nn.Sigmoid())

训练过程

  • 艺术生成:G生成抽象艺术,D判别真实度
  • 交叉训练:G试图生成接近真实数据的样本,D试图识别真实数据
  • 动态可视化:每50次迭代展示生成样本和训练过程

以上就是本周学习内容的各个实战报告。

转载地址:http://pwdhz.baihongyu.com/

你可能感兴趣的文章
PHP反射机制
查看>>
php取当天的最后一秒_Docker快速搭建PHP开发环境详细教程
查看>>
php取绝对值
查看>>
PHP变量内容的获取
查看>>
php各种常用的算法
查看>>
php各种缓存策略对比
查看>>
RabbitMQ高级特性 - 消息分发(限流、负载均衡)
查看>>
php后台“爬虫”模拟登录第三方系统
查看>>
php后台的在控制器中就可以实现阅读数增加
查看>>
php命令行生成项目结构
查看>>
php命名空间
查看>>
PHP命名空间带来的干扰
查看>>
PHP和MySQL Web开发从新手到高手,第1天-搭建PHP开发环境
查看>>
php商店管理系统,基于PHP的商店管理系统.doc
查看>>
PHP四大主流框架的优缺点总结
查看>>
PHP图片处理—PNG透明缩放并生成灰图
查看>>
php在liunx系统中设置777权限不起作用解决方法
查看>>
PHP基于openssl实现的非对称加密操作
查看>>
php基本符号大全
查看>>
php基础篇-二维数组排序 array_multisort
查看>>