【Deep Learning 02】 Multivariable Linear Regression


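Contents
    • Basic principle
    • Gradient descent
      • 1. Steps
      • 2. Mini-batch gradient descent
    • Linear regression implementation
      • 1. Generate the dataset
      • 2. Visualize the data
      • 3. Read the data
      • 4. Define the model
      • 5. Initialize the parameters
      • 6. Define the loss function
      • 7. Define the optimization algorithm
      • 8. Train the model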

⭐ This post covers: the mathematical derivation of multivariable linear regression, gradient descent, and an implementation in PyTorch.

💌 References:

ML1 Univariate Linear Regression (什么都只会一点's blog, CSDN)

dl3.2 Linear Regression | Kaggle

3.3. Concise Implementation of Linear Regression, Dive into Deep Learning 2.0.0-beta0 documentation (d2l.ai)

Basic principle

Let each input have d features, and let the training set contain n examples $(\mathbf{x}^{(i)}, y^{(i)})$. The model, the squared loss of one prediction, the average loss over the training set, and the training objective are:

$$
\begin{aligned}
\mathbf{x} &= \left[x_{1}, x_{2}, \cdots, x_{d}\right]^{\top}, \qquad \mathbf{w} = \left[w_{1}, w_{2}, \cdots, w_{d}\right]^{\top} \\
\hat{y} &= w_{1} x_{1} + w_{2} x_{2} + \cdots + w_{d} x_{d} + b = \langle \mathbf{w}, \mathbf{x} \rangle + b \\
D(y) &= \frac{1}{2}\left(y - \hat{y}\right)^{2} \\
l(\mathbf{X}, \mathbf{y}, \mathbf{w}, b) &= \frac{1}{2n} \sum_{i=1}^{n} \left(y^{(i)} - \hat{y}^{(i)}\right)^{2} = \frac{1}{2n} \sum_{i=1}^{n} \left(y^{(i)} - \langle \mathbf{w}, \mathbf{x}^{(i)} \rangle - b\right)^{2} \\
\mathbf{w}^{*}, b^{*} &= \arg\min_{\mathbf{w},\, b} \; l(\mathbf{X}, \mathbf{y}, \mathbf{w}, b)
\end{aligned}
$$

Training means finding the parameters $\mathbf{w}^{*}, b^{*}$ that minimize the average loss $l$.
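To make the formulas concrete, here is a minimal plain-Python sketch (no PyTorch) that evaluates the prediction $\hat{y} = \langle \mathbf{w}, \mathbf{x} \rangle + b$ and the average loss $l$ on a tiny hand-made dataset with d = 2 features and n = 3 examples. The helper names `predict` and `loss` and all numbers are illustrative assumptions, not part of the original post.

```python
def predict(w, b, x):
    """Prediction y_hat = <w, x> + b for a single example x."""
    return sum(w_j * x_j for w_j, x_j in zip(w, x)) + b


def loss(w, b, X, y):
    """Average loss l = 1/(2n) * sum_i (y_i - <w, x_i> - b)^2 over n examples."""
    n = len(X)
    return sum((y_i - predict(w, b, x_i)) ** 2 for x_i, y_i in zip(X, y)) / (2 * n)


# Hypothetical toy data: n = 3 examples with d = 2 features each
X = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]
y = [3.0, 1.0, 2.0]
w = [1.0, 0.5]
b = 0.5

print([predict(w, b, x_i) for x_i in X])  # [2.5, 1.0, 3.0]
print(loss(w, b, X, y))                   # (0.5**2 + 0**2 + 1**2) / 6 = 1.25 / 6
```

Each residual $y^{(i)} - \hat{y}^{(i)}$ is squared, so over- and under-predictions of the same size are penalized equally.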

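Gradient descent

1. Steps
  • Choose an initial value $\mathbf{w}_0$

  • Update the parameters iteratively for t = 1, 2, 3, …:

    $\mathbf{w}_{t}=\mathbf{w}_{t-1}-\eta \frac{\partial l}{\partial \mathbf{w}_{t-1}}$

  • Moving along the gradient increases the loss, so each update steps in the opposite (negative gradient) direction. For the first step, the move from $\mathbf{w}_0$ to $\mathbf{w}_1$ is

    $\mathbf{w}_{0}-\mathbf{w}_{1}=\eta \frac{\partial l}{\partial \mathbf{w}_{0}}$

  • 🍔 Learning rate η: the hyperparameter that sets the step size. If it is too small, training converges slowly; if it is too large, the updates can overshoot the minimum and diverge.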
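The steps above can be sketched in plain Python for the squared loss of the previous section. This is a minimal full-batch version, assuming hand-derived gradients $\partial l / \partial w_j = -\frac{1}{n}\sum_i (y^{(i)} - \hat{y}^{(i)})\, x^{(i)}_j$ and $\partial l / \partial b = -\frac{1}{n}\sum_i (y^{(i)} - \hat{y}^{(i)})$, a zero initial value, and hypothetical helper names; the learning rate and number of steps are arbitrary choices that happen to work for this toy problem.

```python
def gradients(w, b, X, y):
    """Gradients of l = 1/(2n) * sum_i (y_i - <w, x_i> - b)^2 with respect to w and b."""
    n = len(X)
    grad_w = [0.0] * len(w)
    grad_b = 0.0
    for x_i, y_i in zip(X, y):
        residual = y_i - (sum(w_j * x_j for w_j, x_j in zip(w, x_i)) + b)
        for j, x_j in enumerate(x_i):
            grad_w[j] -= residual * x_j / n
        grad_b -= residual / n
    return grad_w, grad_b


def gradient_descent(X, y, eta=0.1, num_steps=2000):
    w = [0.0] * len(X[0])  # initial value w_0 (all zeros, an arbitrary choice)
    b = 0.0
    for _ in range(num_steps):  # t = 1, 2, 3, ...
        grad_w, grad_b = gradients(w, b, X, y)
        # w_t = w_{t-1} - eta * dl/dw_{t-1}: step against the gradient
        w = [w_j - eta * g_j for w_j, g_j in zip(w, grad_w)]
        b = b - eta * grad_b
    return w, b


# Hypothetical noise-free data generated from w = [2, -3.4], b = 4.2
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [-1.0, 0.5]]
y = [2 * x1 - 3.4 * x2 + 4.2 for x1, x2 in X]

w_hat, b_hat = gradient_descent(X, y)
print(w_hat, b_hat)  # close to [2.0, -3.4] and 4.2
```

Every step here uses all n examples to compute the gradient, which is exactly the cost that mini-batch gradient descent reduces.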

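2. Mini-batch gradient descent

Mini-batch stochastic gradient descent is the default optimization method in deep learning.

Computing the gradient over the entire training set at every step is expensive. To save training time, each step randomly samples b examples $i_1, i_2, \ldots, i_b$ (the mini-batch $I_b$; here b is the batch size, not the bias) and uses their average loss to approximate the full loss:

$$\frac{1}{b} \sum_{i \in I_{b}} l\left(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}\right)$$

where $l\left(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}\right)$ is the loss on example i. The parameters are then updated with the gradient of this approximation:

$$\mathbf{w}_{t}=\mathbf{w}_{t-1}-\frac{\eta}{b} \sum_{i \in I_{b}} \frac{\partial l\left(\mathbf{x}^{(i)}, y^{(i)}, \mathbf{w}_{t-1}\right)}{\partial \mathbf{w}_{t-1}}$$

  • 🍟 b is the batch size, the number of training examples sampled per step. It should be neither too small nor too large: a very small batch gives noisy gradient estimates and makes poor use of parallel hardware, while a very large batch costs more computation and memory per step.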
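A minimal plain-Python sketch of mini-batch SGD for the same model, assuming synthetic data drawn like the dataset generated later in this post ($\mathbf{w} = [2, -3.4]^\top$, b = 4.2, noise standard deviation 0.01). Instead of drawing every batch independently, it shuffles the example indices once per epoch and walks through them in chunks of `batch_size`, a common way of sampling without replacement. The bias is called `b` and the batch size `batch_size` to keep the two apart; the function name and hyperparameters are illustrative assumptions.

```python
import random


def minibatch_sgd(X, y, batch_size=10, eta=0.03, num_epochs=20, seed=0):
    """Fit y ~ <w, x> + b with mini-batch SGD on the squared loss (illustrative helper)."""
    rng = random.Random(seed)
    d = len(X[0])
    w, b = [0.0] * d, 0.0
    indices = list(range(len(X)))
    for _ in range(num_epochs):
        rng.shuffle(indices)  # new random order each epoch
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]  # the mini-batch I_b
            grad_w, grad_b = [0.0] * d, 0.0
            for i in batch:
                residual = y[i] - (sum(w_j * x_j for w_j, x_j in zip(w, X[i])) + b)
                for j in range(d):
                    grad_w[j] -= residual * X[i][j]  # d/dw_j of 0.5 * residual^2
                grad_b -= residual                   # d/db of 0.5 * residual^2
            # average the per-example gradients over the batch and step against them
            w = [w_j - eta * g_j / len(batch) for w_j, g_j in zip(w, grad_w)]
            b -= eta * grad_b / len(batch)
    return w, b


# Synthetic data in the spirit of the dataset below: w = [2, -3.4], b = 4.2, noise std 0.01
data_rng = random.Random(42)
true_w, true_b = [2.0, -3.4], 4.2
X = [[data_rng.gauss(0, 1), data_rng.gauss(0, 1)] for _ in range(1000)]
y = [true_w[0] * x1 + true_w[1] * x2 + true_b + data_rng.gauss(0, 0.01) for x1, x2 in X]

w_hat, b_hat = minibatch_sgd(X, y)
print(w_hat, b_hat)  # should be close to [2, -3.4] and 4.2
```

Each update touches only `batch_size` examples, so one pass over the 1000 examples performs 100 parameter updates instead of one.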
Linear regression implementation
import torch
import random
import numpy as np
import matplotlib_inline.backend_inline
from matplotlib import pyplot as plt
1. Generate the dataset

Define the true weights and bias used to generate the data

num_inputs = 2  # number of input features per example (x1, x2)
num_examples = 1000  # number of examples
true_w = [2, -3.4]
true_b = 4.2

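Generate the input `features` by sampling from a standard normal distribution: 1000 examples, each with two features $x_1$ and $x_2$ (one per entry of `true_w`).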

features = torch.tensor(np.random.normal(0, 1, (num_examples,num_inputs)), dtype=torch.float)

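Compute the output `labels` from `features` with the true parameters, $y = 2x_1 - 3.4x_2 + 4.2$, then add Gaussian noise with standard deviation 0.01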

labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels = labels + torch.tensor(np.random.normal(0, 0.01, size=labels.size()), dtype=torch.float)  # add Gaussian noise (std 0.01) to the noise-free labels
2. Visualize the data
def use_svg_display():
    # display figures as SVG (vector graphics)
    matplotlib_inline.backend_inline.set_matplotlib_formats('svg')
    
def set_figsize(figsize=(3.5, 2.5)):
    use_svg_display()
    # set the figure size
    plt.rcParams['figure.figsize'] = figsize
    

set_figsize()
plt.scatter(features[:, 1].numpy(), labels.numpy(), 1);  # second feature vs. label, marker size 1

3. Read the data

PyTorch provides the `torch.utils.data` package for reading data.

With batch_size = 10, each step randomly reads a mini-batch of 10 examples.

import torch.utils.data as Data

batch_size = 10

Inspect the generated features and labels:

features,labels
(tensor([[ 2.1282, -0.8920],
         [-1.0093,  0.2924],
         [ 1.5258, -0.5783],
         ...,
         [-1.6918, -0.7286],
         [ 0.2524,  1.0183],
         [-0.7605, -1.3882]]),
 tensor([ 1.1486e+01,  1.1739e+00,  9.2210e+00,  ...,  3.2928e+00,
          1.2413e+00,  7.4122e+00]))

Combine the training features and labels into a dataset:

dataset = Data.TensorDataset(features,labels)
dataset

Read the data in random mini-batches; `shuffle=True` reshuffles the examples at every epoch:

data_iter = Data.DataLoader(dataset, batch_size, shuffle=True)
data_iter

Print the first mini-batch:

for x, y in data_iter:
    print(x, y)
    break
tensor([[ 1.8036,  0.2046],
        [-0.1144,  1.2728],
        [ 0.4237, -0.0737],
        [ 0.9859, -0.9073],
        [-1.9821,  2.7871],
        [-0.7970, -1.2866],
        [ 0.1646, -0.6856],
        [ 0.7002,  1.0762],
        [-0.9576, -1.3074],
        [ 0.7070, -0.1870]]) tensor([ 7.1322, -0.3659,  5.2927,  9.2499, -9.2418,  6.9704,  6.8748,  1.9477,
         6.7563,  6.2395])
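Conceptually, `TensorDataset` pairs the i-th feature row with the i-th label, and `DataLoader(..., shuffle=True)` shuffles the indices and serves them in chunks of `batch_size`. The plain-Python generator below is a simplified sketch of that behaviour, assuming list inputs; the name `data_iter_sketch` and the toy data are hypothetical, and the real `DataLoader` additionally collates batches into tensors, supports worker processes, and more. Like `DataLoader` with its default `drop_last=False`, it yields a smaller final batch when the dataset size is not a multiple of the batch size.

```python
import random


def data_iter_sketch(batch_size, features, labels, seed=None):
    """Yield (feature_batch, label_batch) pairs in a random order (hypothetical helper)."""
    rng = random.Random(seed)
    indices = list(range(len(features)))
    rng.shuffle(indices)  # shuffle the example order
    for start in range(0, len(indices), batch_size):
        batch = indices[start:start + batch_size]
        yield [features[i] for i in batch], [labels[i] for i in batch]


# Hypothetical toy data: 25 examples whose label equals the first feature
features = [[float(i), float(-i)] for i in range(25)]
labels = [float(i) for i in range(25)]

batches = list(data_iter_sketch(10, features, labels, seed=0))
print([len(y) for _, y in batches])  # [10, 10, 5]
```

Because features and labels are indexed with the same shuffled indices, each label stays attached to its own feature row.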
4. Define the model

Import the `torch.nn` module:

import torch.nn as nn

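`nn.Sequential` makes it more convenient to build a network. `Sequential` is an ordered container: layers are added to the computation graph in the order in which they are passed to it.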

net = nn.Sequential(
    nn.Linear(num_inputs, 1)
    # other layers can be passed here as well
    )

print(net)
print(net[0])
Sequential(
  (0): Linear(in_features=2, out_features=1, bias=True)
)
Linear(in_features=2, out_features=1, bias=True)

`net.parameters()` returns a generator over all of the model's learnable parameters; iterate over it to inspect them:

for param in net.parameters():
    print(param)
Parameter containing:
tensor([[0.6623, 0.0845]], requires_grad=True)
Parameter containing:
tensor([0.1282], requires_grad=True)
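`nn.Linear(2, 1)` stores a weight of shape (out_features, in_features) = (1, 2) and a bias of shape (1,), and maps a batch `x` to `x @ weight.T + bias`. As a sketch, the plain-Python function below (a hypothetical helper named `linear_forward`) reproduces that computation with the parameter values printed above. The output has one row per example and one column, i.e. shape (batch_size, 1), which is why the training loop reshapes the labels with `y.view(-1, 1)`.

```python
def linear_forward(x_batch, weight, bias):
    """Compute x @ weight.T + bias for a batch of examples given as lists."""
    return [
        [sum(w_j * x_j for w_j, x_j in zip(row, x)) + b_k for row, b_k in zip(weight, bias)]
        for x in x_batch
    ]


weight = [[0.6623, 0.0845]]  # shape (1, 2), the random initial weight printed above
bias = [0.1282]              # shape (1,)

out = linear_forward([[1.0, 2.0], [0.0, 0.0]], weight, bias)
print(out)  # one output per example: roughly [[0.9595], [0.1282]]
```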
5. Initialize the parameters
from torch.nn import init

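We use `init.normal_` to initialize each element of the weight to a random sample from a normal distribution with mean 0 and standard deviation 0.01, and initialize the bias to zero.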

init.normal_(net[0].weight,mean=0,std=0.01)
init.constant_(net[0].bias, val=0)  # alternatively, set the bias data directly, as in the next line
net[0].bias.data.fill_(0)
tensor([0.])
6. Define the loss function

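Use the mean squared error loss, `nn.MSELoss`, as the model's loss function. With its default `reduction='mean'` it averages $(\hat{y} - y)^2$ over the batch without the factor ½ used in the derivation, so its value and gradient are twice those of $l$; this constant factor only rescales the effective learning rate.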

loss = nn.MSELoss()
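By default (`reduction='mean'`), `nn.MSELoss` computes the mean of $(\hat{y} - y)^2$ without the factor ½ that appears in the loss $l$ at the top of this post, so it equals $2l$. A quick plain-Python check of that relationship, with hypothetical helper names and toy values:

```python
def mse_mean(y_hat, y):
    """Mean squared error as nn.MSELoss() computes it by default: mean((y_hat - y)^2)."""
    return sum((p - t) ** 2 for p, t in zip(y_hat, y)) / len(y)


def half_mse(y_hat, y):
    """The loss l from the derivation: 1/(2n) * sum((y - y_hat)^2)."""
    return sum((t - p) ** 2 for p, t in zip(y_hat, y)) / (2 * len(y))


y_hat = [2.5, 1.0, 3.0]
y = [3.0, 1.0, 2.0]
print(mse_mean(y_hat, y), half_mse(y_hat, y))  # 1.25 / 3 and 1.25 / 6
```

Since the two differ only by a constant factor, they have the same minimizer $\mathbf{w}^{*}, b^{*}$.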
7. Define the optimization algorithm

Use mini-batch stochastic gradient descent (SGD) with a learning rate of 0.03 as the optimization algorithm

import torch.optim as optim

optimizer = optim.SGD(
    # a parameter group that does not specify its own lr uses this default learning rate
    net.parameters(),lr=0.03
    )
print(optimizer)
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.03
    momentum: 0
    nesterov: False
    weight_decay: 0
)

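Adjust the learning rate dynamically by modifying the optimizer's `param_groups` in place, so no new optimizer has to be constructed. Note that this multiplies the learning rate by 0.1, so the training below actually runs with lr = 0.003.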

for param_group in optimizer.param_groups:
    param_group['lr'] *= 0.1  # scale the learning rate to 0.1x its previous value
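`optimizer.param_groups` is a list of dictionaries, one per parameter group, and the loop above simply rescales the `'lr'` entry of each. The plain-Python sketch below mimics that with hypothetical stand-in dictionaries (it does not use PyTorch) and extends it to a step-decay schedule that divides the learning rate by 10 after every 5th epoch. PyTorch also provides schedulers for this purpose, such as `torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma)`; the sketch only illustrates the idea and is not that class's implementation.

```python
def scale_lr(param_groups, factor):
    """Multiply the learning rate of every parameter group in place."""
    for param_group in param_groups:
        param_group["lr"] *= factor


# Hypothetical stand-in for optimizer.param_groups after optim.SGD(net.parameters(), lr=0.03)
param_groups = [{"lr": 0.03, "momentum": 0}]
scale_lr(param_groups, 0.1)
print(param_groups[0]["lr"])  # about 0.003, the rate the training below actually uses

# Step decay: record the lr used in each epoch, shrinking it 10x after every 5th epoch
groups = [{"lr": 0.03}]
lr_per_epoch = []
for epoch in range(1, 16):
    lr_per_epoch.append(groups[0]["lr"])
    if epoch % 5 == 0:
        scale_lr(groups, 0.1)
```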
8. Train the model
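num_epochs = 10  # number of passes (epochs) over the training data; more epochs lower the training loss until it converges
for epoch in range(1, num_epochs + 1):
    for x, y in data_iter:
        output = net(x)
        l = loss(output, y.view(-1, 1))  # reshape y to (batch_size, 1) to match the output
        optimizer.zero_grad()  # reset gradients; same as net.zero_grad() here, since the optimizer holds all of net's parameters
        l.backward()           # compute gradients by backpropagation
        optimizer.step()       # take one SGD step
    print('epoch: %d,loss: %f' % (epoch, l.item()))  # loss of the last mini-batch in this epoch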
epoch: 1,loss: 15.321363
epoch: 2,loss: 6.421220
epoch: 3,loss: 0.666509
epoch: 4,loss: 0.269027
epoch: 5,loss: 0.071433
epoch: 6,loss: 0.027302
epoch: 7,loss: 0.005877
epoch: 8,loss: 0.001547
epoch: 9,loss: 0.000483
epoch: 10,loss: 0.000219

Compare the learned parameters with the true ones

dense = net[0]
print(true_w, dense.weight)
print(true_b, dense.bias)
[2, -3.4] Parameter containing:
tensor([[ 1.9972, -3.3939]], requires_grad=True)
4.2 Parameter containing:
tensor([4.1887], requires_grad=True)
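To quantify how close the learned parameters are, one can subtract them from the true values. The sketch below does this in plain Python with the numbers printed above typed in by hand; in a live session one would use `true_w` and `dense.weight` directly instead.

```python
true_w, true_b = [2, -3.4], 4.2
learned_w, learned_b = [1.9972, -3.3939], 4.1887  # values printed by the training run above

error_w = [t - p for t, p in zip(true_w, learned_w)]
error_b = true_b - learned_b
print("error in w:", [round(e, 4) for e in error_w])  # [0.0028, -0.0061]
print("error in b:", round(error_b, 4))                # 0.0113
```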

Feel free to share; please credit the source when reposting: 内存溢出

Original URL: https://www.outofmemory.cn/langs/892567.html
