# Libtorch 基本模型

torch::nn::Module 的第一句注释：The design and implementation of this class is largely based on the Python API.

## 1. 数据准备

// Ground-truth weight (shape {1, 2}) and a synthetic regression set:
// x ~ U[0, 1) of shape {20, 2}; b is per-sample Gaussian noise centered at 3;
// targets y = x * w^T + b, shape {20, 1}.
auto w = torch::tensor({{1.0, 2.0}});
auto x = torch::rand({20, 2});
auto b = torch::randn({20, 1}) + 3;
auto y = x.mm(w.t()) + b;

// Two fake "image" classes for the CNN demo, 10 samples each in NCHW layout
// (10, 1, 28, 28), drawn from Gaussians with different means so they are
// separable; labels are int64 (kLong) as required by classification losses.
auto img0 = torch::randn({10, 1, 28, 28}) * 100 + 100;
auto label0 = torch::zeros({10}, torch::kLong);
auto img1 = torch::randn({10, 1, 28, 28}) * 100 + 150;
auto label1 = torch::ones({10}, torch::kLong);
auto img = torch::cat({img0, img1});
auto label = torch::cat({label0, label1});

## 2. 线性回归

// Fit a 2-in / 1-out linear layer to (x, y) with plain SGD, lr = 0.1.
torch::nn::Linear lin(2, 1);
torch::optim::SGD sgd(lin->parameters(), 0.1);
for (int i = 0; i < 10; i++) {
  torch::Tensor y_ = lin(x);
  torch::Tensor loss = torch::mse_loss(y_, y);
  // BUG FIX: autograd *accumulates* gradients into .grad; without
  // zero_grad() each step() applies the sum of all previous gradients.
  sgd.zero_grad();
  loss.backward();
  sgd.step();
  std::cout << "Epoch " << i << " loss=" << loss.item<float>() << std::endl;
}

// A Linear layer without the bias term (computes y = x * W^T only),
// configured through LinearOptions instead of the (in, out) shortcut.
torch::nn::Linear lin_no_bias(torch::nn::LinearOptions(2,1).bias(false));

## 3. 多层感知机

(lin1): Linear(2, 4)
(relu): ReLU()
(lin2): Linear(4, 4)
(relu): ReLU()
(lin3): Linear(4, 1)

include 目录下创建两个文件 MLP.h 和 MLP.cpp，创建 MLP 类并实现构造函数和 forward()（其实套路和 PyTorch 也差不多）。

// MLP.h
// Three-layer perceptron: in_dim -> hidden_dim -> hidden_dim -> out_dim,
// with ReLU between layers (see MLP::forward in MLP.cpp).
class MLP : public torch::nn::Module {
public:
// Builds the three Linear layers and registers them on the module tree
// so parameters() / to(device) / serialization can see them.
MLP(int in_dim, int hidden_dim,int out_dim);
// lin1 -> ReLU -> lin2 -> ReLU -> lin3; no activation on the output.
torch::Tensor forward(torch::Tensor x);

private:
// Module holders start empty ({nullptr}) and are assigned in the constructor.
torch::nn::Linear lin1{nullptr};
torch::nn::Linear lin2{nullptr};
torch::nn::Linear lin3{nullptr};
};

// MLP.cpp
// Construct and register each Linear layer in one step: register_module
// returns the holder it was given, so assignment and registration cannot
// drift apart. (Also drops the stray ';' after the original function body.)
MLP::MLP(int in_dim, int hidden_dim, int out_dim) {
  lin1 = register_module("lin1", torch::nn::Linear(in_dim, hidden_dim));
  lin2 = register_module("lin2", torch::nn::Linear(hidden_dim, hidden_dim));
  lin3 = register_module("lin3", torch::nn::Linear(hidden_dim, out_dim));
}

// Forward pass: two hidden layers with ReLU, linear output (no activation).
torch::Tensor MLP::forward(torch::Tensor x) {
  auto h = torch::relu(lin1->forward(x));
  h = torch::relu(lin2->forward(h));
  return lin3->forward(h);
}

# chap4/include/CMakeLists.txt
target_link_libraries(libchap4 ${TORCH_LIBRARIES})

主目录下面的 CMakeLists.txt 也需要做一些修改。

# chap4/CMakeLists.txt
cmake_minimum_required(VERSION 3.21)
project(BasicModels)
find_package(Torch REQUIRED)
add_subdirectory(include)
add_executable(BasicModels BasicModels.cpp)
target_link_libraries(BasicModels ${TORCH_LIBRARIES} libchap4)

## 4. 卷积网络

(conv1): Conv2d(1, 16, kernel_size=(3, 3), padding=(1, 1))
(bn): BatchNorm2d(16)
(max_pool): MaxPool2d(2)
(relu): ReLU()
(conv2): Conv2d(16, 16, kernel_size=(3, 3), padding=(1, 1))
(bn): BatchNorm2d(16)
(max_pool): MaxPool2d(2)
(relu): ReLU()
(lin): Linear(784, 2)

// CNN.h
// Two conv blocks (conv -> BatchNorm -> ReLU -> 2x2 max-pool) followed by a
// Linear classifier. bn / relu / max_pool holders are shared by both blocks.
class CNN : public torch::nn::Module {
public:
// num_classes sizes the final Linear layer's output.
CNN(int num_classes);
// Expects NCHW input; see CNN.cpp for the layer sequence.
torch::Tensor forward(torch::Tensor x);

private:
// NOTE(review): conv1/conv2 are never assigned in the constructor shown in
// CNN.cpp below, so they remain nullptr — invoking them in forward() will
// dereference an empty module holder. Confirm the ctor constructs them.
torch::nn::Conv2d conv1{nullptr};
torch::nn::Conv2d conv2{nullptr};
torch::nn::ReLU relu{nullptr};
torch::nn::MaxPool2d max_pool{nullptr};
torch::nn::BatchNorm2d bn{nullptr};
torch::nn::Linear lin{nullptr};
};
// CNN.cpp
// CNN.cpp
CNN::CNN(int num_classes) {
  // BUG FIX: conv1/conv2 were never constructed — they stayed {nullptr} yet
  // were registered and later called in forward(), which dereferences an
  // empty module holder. Build them per the architecture listing:
  // Conv2d(1, 16, 3x3, padding 1) and Conv2d(16, 16, 3x3, padding 1).
  conv1 = torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 16, 3).padding(1));
  conv2 = torch::nn::Conv2d(torch::nn::Conv2dOptions(16, 16, 3).padding(1));
  bn = torch::nn::BatchNorm2d(16);
  relu = torch::nn::ReLU();
  max_pool = torch::nn::MaxPool2d(2);
  // 28x28 input halved by two 2x2 pools -> 7x7 spatial, 16 channels.
  lin = torch::nn::Linear(7 * 7 * 16, num_classes);

  register_module("conv1", conv1);
  register_module("conv2", conv2);
  register_module("bn", bn);
  register_module("relu", relu);
  register_module("max_pool", max_pool);
  register_module("lin", lin);
}

// Forward pass for NCHW input (e.g. {N, 1, 28, 28}).
torch::Tensor CNN::forward(torch::Tensor x) {
  // Block 1: conv -> BN -> ReLU -> 2x2 pool (28x28 -> 14x14).
  // FIX: use the registered max_pool member instead of the free function
  // max_pool2d(x, 2) — the member was registered but never used.
  x = conv1(x);
  x = bn(x);
  x = relu(x);
  x = max_pool(x);

  // Block 2 (reuses the same bn/relu/max_pool modules): 14x14 -> 7x7.
  x = conv2(x);
  x = bn(x);
  x = relu(x);
  x = max_pool(x);

  // Flatten to {N, 7*7*16} and classify.
  return lin(x.reshape({x.size(0), -1}));
}

Netty源码剖析与实战 -〔傅健〕

To B市场品牌实战课 -〔曹林〕

PyTorch深度学习实战 -〔方远〕