PyTorch学习笔记:F.normalize——数组归一化运算

PyTorch学习笔记:F.normalize——数组归一化运算

torch.nn.functional.normalize(input, p=2.0, dim=1, eps=1e-12, out=None)

功能:利用

L

p

L_p

Lp​范数对输入的数组沿特定的维度进行归一化

  对于尺寸为

(

n

0

,

,

n

d

i

m

,

,

n

k

)

(n_0,\dots,n_{dim},\dots,n_k)

(n0​,…,ndim​,…,nk​)的输入数组input,每个

n

d

i

m

n_{dim}

ndim​上的元素向量

v

v

v沿着维度dim进行转换,转换公式为:

v

=

v

max

(

v

p

,

ϵ

)

v=\frac{v}{\max(||v||_p,\epsilon)}

v=max(∣∣v∣∣p​,ϵ)v​

范数计算公式

对于数据

x

=

[

x

1

,

x

2

,

,

x

n

]

T

x=[x_1,x_2,\dots,x_n]^T

x=[x1​,x2​,…,xn​]T:

  • L

    p

    L_p

    Lp​范数:

    x

    p

    =

    (

    x

    1

    p

    +

    x

    2

    p

    +

    +

    x

    n

    p

    )

    1

    p

    ||x||_p=(|x_1|^p+|x_2|^p+\dots+|x_n|^p)^{\frac1p}

    ∣∣x∣∣p​=(∣x1​∣p+∣x2​∣p+⋯+∣xn​∣p)p1​

  • L

    1

    L_1

    L1​范数:

    x

    1

    =

    x

    1

    +

    x

    2

    +

    +

    x

    n

    ||x||_1=|x_1|+|x_2|+\dots+|x_n|

    ∣∣x∣∣1​=∣x1​∣+∣x2​∣+⋯+∣xn​∣

  • L

    2

    L_2

    L2​范数:

    x

    2

    =

    (

    x

    1

    2

    +

    x

    2

    2

    +

    +

    x

    n

    2

    )

    1

    2

    ||x||_2=(|x_1|^2+|x_2|^2+\dots+|x_n|^2)^{\frac12}

    ∣∣x∣∣2​=(∣x1​∣2+∣x2​∣2+⋯+∣xn​∣2)21​

输入:

  • input:输入的数组,数组数据类型为float
  • p:指定使用的范数,数据类型为float,默认2.0
  • dim:指定的维度,数据类型为int,默认1
  • eps:边界值,防止分母为0,默认1e-12

代码案例

一般用法

import torch.nn.functional as F
import torch

a = torch.arange(20, dtype=torch.float).reshape(4,5)
b = F.normalize(a, dim=0)
c = F.normalize(a, dim=1)
print(a)
print(b)
print(c)

输出

# 输入的数组
tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.],
        [15., 16., 17., 18., 19.]])
# dim=0时,即沿第一维度(列)做归一化
tensor([[0.0000, 0.0491, 0.0907, 0.1261, 0.1564],
        [0.2673, 0.2949, 0.3175, 0.3363, 0.3519],
        [0.5345, 0.5406, 0.5443, 0.5464, 0.5474],
        [0.8018, 0.7864, 0.7711, 0.7566, 0.7430]])
# dim=1时,即沿第二维度(行)做归一化
# 维度记忆技巧:最后一个维度始终是行,从后向前推:行、列、通道
tensor([[0.0000, 0.1826, 0.3651, 0.5477, 0.7303],
        [0.3131, 0.3757, 0.4384, 0.5010, 0.5636],
        [0.3701, 0.4071, 0.4441, 0.4812, 0.5182],
        [0.3932, 0.4195, 0.4457, 0.4719, 0.4981]])

官方文档

torch.nn.functional.normalize:https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html#torch.nn.functional.normalize

初步完稿于:2022年2月6日

本文来自网络,不代表协通编程立场,如若转载,请注明出处:https://www.net2asp.com/1cbb99d2e7.html