Never stop learning, approaching the AI

3D CNN sample code

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Define a simple synthetic dataset
class SyntheticDataset(Dataset):
    def __init__(self, num_samples, num_classes):
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.data = torch.randn(num_samples, 1, 64, 64, 64)
        self.targets = torch.randint(0, num_classes, (num_samples,))
        self.normalize_data()

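    def normalize_data(self):
        # Min-max normalize each sample so that its values lie in [0, 1]
        dims = (1, 2, 3, 4)
        lo = self.data.amin(dim=dims, keepdim=True)
        hi = self.data.amax(dim=dims, keepdim=True)
        self.data = (self.data - lo) / (hi - lo)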

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

# Define the 3D CNN model
class Simple3DCNN(nn.Module):
    def __init__(self, num_classes):
        super(Simple3DCNN, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool3d(2, 2)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create a synthetic dataset instance
num_samples = 100
num_classes = 4
dataset = SyntheticDataset(num_samples, num_classes)

# Create the data loader
batch_size = 10
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Instantiate the model
model = Simple3DCNN(num_classes)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    loss_list = []
    accuracy_list = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate the summed loss and the number of correct predictions
            running_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)
        loss_list.append(epoch_loss)
        accuracy_list.append(epoch_acc)

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Plot the loss and accuracy curves
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(loss_list, label='Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(accuracy_list, label='Accuracy')
    plt.title('Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.show()

# Call the training function
train_model(model, train_loader, criterion, optimizer, num_epochs=10)
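
The `64 * 8 * 8 * 8` input size of `fc1` in the script above follows from the layer shapes: each `Conv3d` with `kernel_size=3, padding=1` keeps the spatial size, and each `MaxPool3d(2, 2)` halves it, rounding down. A minimal sketch of that arithmetic in plain Python follows. The helper names `conv_out`, `pool_out`, and `flat_features` are hypothetical and not part of PyTorch, and the sketch assumes stride-1 convolutions without dilation, as in the model above.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    """Output length of one axis after a convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output length of one axis after max pooling without padding."""
    return (size - kernel) // stride + 1


def flat_features(input_shape, channels):
    """Flattened feature count after conv(3, pad 1) + pool(2) stages.

    input_shape: spatial sizes of the three axes.
    channels: out_channels of each stage, in order.
    """
    shape = list(input_shape)
    for _ in channels:
        # One stage = size-preserving convolution followed by 2x pooling
        shape = [pool_out(conv_out(s)) for s in shape]
    total = channels[-1]
    for s in shape:
        total *= s
    return total, tuple(shape)


total, shape = flat_features((64, 64, 64), channels=[16, 32, 64])
print(shape, total)  # (8, 8, 8) 32768
```

For a 64 x 64 x 5 input and two stages (`channels=[16, 32]`), the same helper returns `(16, 16, 1)` and 8192, which equals `32 * 16 * 16`.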

# Variant: the same pipeline for inputs of shape 64 x 64 x 5, using two convolution layers
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Define a simple synthetic dataset
class SyntheticDataset(Dataset):
    def __init__(self, num_samples, num_classes):
        self.num_samples = num_samples
        self.num_classes = num_classes
        self.data = torch.randn(num_samples, 1, 64, 64, 5)
        self.targets = torch.randint(0, num_classes, (num_samples,))
        self.normalize_data()

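    def normalize_data(self):
        # Min-max normalize each sample so that its values lie in [0, 1]
        dims = (1, 2, 3, 4)
        lo = self.data.amin(dim=dims, keepdim=True)
        hi = self.data.amax(dim=dims, keepdim=True)
        self.data = (self.data - lo) / (hi - lo)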

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]


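# Define the 3D CNN model
class Simple3DCNN(nn.Module):

    def __init__(self, num_classes=4):
        super(Simple3DCNN, self).__init__()
        # 3D convolution layers; the input has 1 channel
        self.conv1 = nn.Conv3d(in_channels=1,
                               out_channels=16,
                               kernel_size=(3, 3, 3),
                               padding=1)
        self.conv2 = nn.Conv3d(in_channels=16,
                               out_channels=32,
                               kernel_size=(3, 3, 3),
                               padding=1)
        # Pooling layer (the stride defaults to the kernel size)
        self.pool = nn.MaxPool3d(kernel_size=(2, 2, 2))
        # Fully connected layers. Two poolings reduce a 64 x 64 x 5 input to
        # 16 x 16 x 1, so fc1 receives 32 * 16 * 16 * 1 features; recompute
        # this value if the input shape changes.
        self.fc1 = nn.Linear(32 * 16 * 16, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        # First convolution, followed by ReLU activation and pooling
        x = self.pool(F.relu(self.conv1(x)))
        # Second convolution, followed by ReLU activation and pooling
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten the feature maps for the fully connected layers
        x = x.view(-1, 32 * 16 * 16)
        # First fully connected layer, followed by ReLU activation
        x = F.relu(self.fc1(x))
        # Second fully connected layer produces the class scores
        x = self.fc2(x)
        return x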


# Create a synthetic dataset instance
num_samples = 72
num_classes = 4
dataset = SyntheticDataset(num_samples, num_classes)

# Create the data loader
batch_size = 1
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Instantiate the model
model = Simple3DCNN(num_classes)

# Print the model structure
print(model)


# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
def train_model(model, train_loader, criterion, optimizer, num_epochs=10):
    loss_list = []
    accuracy_list = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Accumulate the summed loss and the number of correct predictions
            running_loss += loss.item() * inputs.size(0)
            preds = outputs.argmax(dim=1)
            running_corrects += (preds == labels).sum().item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)
        loss_list.append(epoch_loss)
        accuracy_list.append(epoch_acc)

        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

    # Plot the loss and accuracy curves
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(loss_list, label='Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(accuracy_list, label='Accuracy')
    plt.title('Training Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.show()

# Call the training function
train_model(model, train_loader, criterion, optimizer, num_epochs=10)
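
The `fc1` input size in the second script is hard-coded as `32 * 16 * 16`, so it has to be recomputed by hand whenever the input shape or the layer stack changes. One way to avoid that, sketched here as an assumption and not as part of the original script, is to pass a dummy tensor through the convolutional part once and read off the flattened size. `infer_flat_features` is a hypothetical helper name, and grouping the layers into an `nn.Sequential` is an illustrative choice.

```python
import torch
import torch.nn as nn


def infer_flat_features(feature_extractor, input_shape):
    """Return the per-sample flattened size produced by feature_extractor.

    input_shape excludes the batch axis, e.g. (1, 64, 64, 5).
    """
    with torch.no_grad():
        # A single all-zero sample is enough to trace the output shape
        dummy = torch.zeros(1, *input_shape)
        out = feature_extractor(dummy)
    return out.flatten(1).size(1)


# Same conv/pool stack as the two-layer model above, grouped into one module
features = nn.Sequential(
    nn.Conv3d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool3d(2),
    nn.Conv3d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool3d(2),
)

n_features = infer_flat_features(features, (1, 64, 64, 5))
print(n_features)  # 8192
fc1 = nn.Linear(n_features, 512)
```

The trade-off is one extra forward pass at construction time in exchange for a layer size that stays correct when the input shape changes.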
Last updated 8 months ago
