
This post documents the principles behind Maximum Mean Discrepancy and provides a Python implementation.

Maximum Mean Discrepancy

Concept

Maximum Mean Discrepancy (MMD) is the most widely used loss function in transfer learning, and in domain adaptation (Domain Adaptation) in particular. In essence, it measures the distance between two distributions in a reproducing kernel Hilbert space (RKHS); put simply, it is used to judge whether two distributions are the same.
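For reference (this framing is standard in the MMD literature rather than part of the original text), MMD is the RKHS distance between the kernel mean embeddings of the two distributions:

$$
\begin{equation*}
\text{MMD}(P,Q)=\left\Vert \mathbb{E}_{x\sim P}[\phi(x)] - \mathbb{E}_{y\sim Q}[\phi(y)] \right\Vert_{\mathcal{H}}
\end{equation*}
$$

where $\phi$ maps a sample into the RKHS $\mathcal{H}$. The empirical estimate given later simply replaces the expectations with sample averages.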

Difference from KL Divergence

Although KL divergence, like MMD, can measure the discrepancy between two samples, it fundamentally measures the information lost when one distribution is used to approximate the other, rather than a distance between them.

As loss functions, the two are, in my view, not fundamentally different: both characterize how two distributions relate. It is their individual limitations and strengths that make each of them suited to different application areas.
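A minimal sketch (my own illustration, not from the original post) of one practical consequence: KL divergence is asymmetric, so it is not a true distance, whereas MMD, being the norm of a difference, is symmetric in its two arguments.

import numpy as np

# Two discrete distributions over the same 3-point support.
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.4, 0.4, 0.2])

def kl(a, b):
    # KL divergence D(a || b) between discrete distributions
    return float(np.sum(a * np.log(a / b)))

# The two directions give different values, i.e. KL(p, q) != KL(q, p);
# MMD(p, q) and MMD(q, p) would be identical by construction.
print(kl(p, q), kl(q, p))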

Formula

A full theoretical treatment of MMD is fairly involved, so only its empirical estimate is given here. Given two datasets $D_s = (x_1,x_2,\cdots,x_m) \sim P(x)$ and $D_t=(y_1,y_2,\cdots,y_n)\sim Q(y)$, the empirical estimate of MMD is:

$$
\begin{equation*}
\widehat{\text{MMD}}(P,Q)=\left \Vert\frac{1}{m}\sum_{x_i} \phi(x_i) - \frac{1}{n}\sum_{y_i} \phi(y_i)\right \Vert
\end{equation*}
$$
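As a quick sanity check (my own example, not from the original text): with the trivial feature map $\phi(x)=x$, the empirical estimate collapses to the distance between the two sample means, which is why a richer feature map (such as the Gaussian kernel's) is needed to separate distributions that happen to share a mean.

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(100, 3))  # m samples from P
y = rng.normal(0.5, 1.0, size=(80, 3))   # n samples from Q

# With phi(x) = x, the empirical MMD is just the distance between sample means.
mmd_linear = np.linalg.norm(x.mean(axis=0) - y.mean(axis=0))
print(mmd_linear)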

Derivation via the Kernel Function

Squaring both sides of the empirical estimate of MMD gives:
$$
\begin{align*}
\widehat{\text{MMD}}(P,Q)^2 &= \left \Vert\frac{1}{m}\sum_{x_i} \phi(x_i) - \frac{1}{n}\sum_{y_i} \phi(y_i)\right \Vert^2 \\
&= \left\Vert\frac{1}{m}\sum_{x_i} \phi(x_i)\right\Vert^2 + \left\Vert\frac{1}{n}\sum_{y_i} \phi(y_i)\right\Vert^2 - 2\left\langle\frac{1}{m}\sum_{x_i} \phi(x_i), \frac{1}{n}\sum_{y_i} \phi(y_i)\right\rangle
\end{align*}
$$
where
$$
\begin{align*}
\left\Vert\frac{1}{m}\sum_{x_i} \phi(x_i)\right\Vert^2 &= \frac{1}{m^2} (\phi(x_1)+\phi(x_2)+\cdots+\phi(x_m))^T(\phi(x_1)+\phi(x_2)+\cdots+\phi(x_m)) \\
&=\frac{1}{m^2}\{\phi(x_1)^T\phi(x_1) + \cdots + \phi(x_1)^T\phi(x_m) + \phi(x_2)^T\phi(x_1) + \cdots + \phi(x_2)^T\phi(x_m) + \cdots + \phi(x_m)^T\phi(x_1) + \cdots + \phi(x_m)^T\phi(x_m)\} \\
&=\frac{1}{m^2}\{k(x_1,x_1)+k(x_1,x_2)+\cdots+k(x_1,x_m) + k(x_2,x_1)+\cdots+k(x_2,x_m) +\cdots+ k(x_m,x_m)\} \\
&=\frac{1}{m^2}\sum_{i,j} k(x_i,x_j)
\end{align*}
$$
Similarly,
$$
\begin{equation*}
\left\Vert\frac{1}{n}\sum_{y_i} \phi(y_i)\right\Vert^2=\frac{1}{n^2}\sum_{i,j} k(y_i,y_j)
\end{equation*}
$$
and, for the cross term,
$$
\begin{equation*}
\left\langle\frac{1}{m}\sum_{x_i} \phi(x_i), \frac{1}{n}\sum_{y_i} \phi(y_i)\right\rangle=\frac{1}{mn}\sum_{i,j}k(x_i,y_j)
\end{equation*}
$$

Using the kernel function, the empirical estimate of MMD can therefore be written as:
$$
\begin{equation*}
\widehat{\text{MMD}}^2=\frac{1}{m^2}\sum_{i,j} k(x_i,x_j)+\frac{1}{n^2}\sum_{i,j} k(y_i,y_j)-\frac{2}{mn}\sum_{i,j}k(x_i,y_j)
\end{equation*}
$$
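The kernelized expression can be verified numerically. The sketch below (mine, not from the original post) uses the quadratic kernel $k(a,b)=(a^Tb)^2$, whose feature map $\phi(a)=\mathrm{vec}(aa^T)$ can be written out explicitly, and checks that both sides agree:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2))  # m = 6 samples
y = rng.normal(size=(4, 2))  # n = 4 samples

def phi(a):
    # Explicit feature map of k(u, v) = (u . v)^2: phi(u) = vec(u u^T),
    # since vec(u u^T) . vec(v v^T) = (u . v)^2.
    return np.einsum('ni,nj->nij', a, a).reshape(len(a), -1)

# (1) MMD^2 computed directly from the feature map.
direct = np.sum((phi(x).mean(0) - phi(y).mean(0)) ** 2)

# (2) MMD^2 from the kernelized formula derived above.
kxx = (x @ x.T) ** 2
kyy = (y @ y.T) ** 2
kxy = (x @ y.T) ** 2
kernelized = kxx.mean() + kyy.mean() - 2 * kxy.mean()

print(np.isclose(direct, kernelized))  # True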

Code Implementation

import torch


def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """
    Convert source-domain and target-domain data into a kernel matrix, namely "k".
    Params:
        source: source domain data (n * len(x))
        target: target domain data (m * len(y))
        kernel_mul: factor by which the bandwidths extend around the base bandwidth,
            e.g. bandwidth/kernel_mul, bandwidth, bandwidth*kernel_mul.
        kernel_num: the number of different Gaussian kernels
        fix_sigma: if given, a fixed bandwidth is used (single-kernel MMD);
            if None, the bandwidths of the different Gaussian kernels are
            derived from the data.
    Return:
        sum(kernel_val): sum of the kernel matrices
    """
    # Number of rows of the joint matrix; source and target are generally the same size.
    n_samples = int(source.size()[0]) + int(target.size()[0])
    total = torch.cat([source, target], dim=0)  # Stack source and target along the sample dimension
    # Replicate total (n+m) times along a new leading dimension
    total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    # Replicate each row of total (n+m) times, i.e. expand each sample into (n+m) copies
    total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
    # Pairwise squared L2 distances: entry (i, j) is the squared distance between
    # the i-th and j-th rows of total (0 when i = j).
    L2_distance = ((total0 - total1) ** 2).sum(2)
    # Choose the bandwidth (sigma) of the Gaussian kernel function
    if fix_sigma:
        bandwidth = fix_sigma
    else:
        bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
    # Take kernel_num bandwidth values centered on the base bandwidth, scaled by kernel_mul;
    # for example, with bandwidth = 1 we get [0.25, 0.5, 1, 2, 4]
    bandwidth /= kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
    # Mathematical expression of the Gaussian kernel: exp(-||x - y||^2 / bandwidth)
    kernel_val = [torch.exp(-L2_distance / (bandwidth_temp + 1e-8)) for bandwidth_temp in bandwidth_list]
    # Sum the kernel matrices to obtain the final multi-kernel matrix
    return sum(kernel_val)


def get_MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """
    Calculate the MMD distance between source-domain and target-domain data.
    Params:
        source: source domain data (n * len(x))
        target: target domain data (m * len(y))
        kernel_mul: factor by which the bandwidths extend around the base bandwidth,
            e.g. bandwidth/kernel_mul, bandwidth, bandwidth*kernel_mul.
        kernel_num: the number of different Gaussian kernels
        fix_sigma: if given, a fixed bandwidth is used (single-kernel MMD);
            if None, the bandwidths of the different Gaussian kernels are
            derived from the data.
    Return:
        loss: MMD loss
    """
    batch_size = int(source.size()[0])  # Source and target batch sizes are assumed equal
    kernels = guassian_kernel(source, target,
                              kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)
    # Split the kernel matrix into 4 blocks and average each one
    XX = torch.mean(kernels[:batch_size, :batch_size])
    YY = torch.mean(kernels[batch_size:, batch_size:])
    XY = torch.mean(kernels[:batch_size, batch_size:])
    YX = torch.mean(kernels[batch_size:, :batch_size])
    loss = XX + YY - XY - YX  # (1/m^2)Σk(x,x) + (1/n^2)Σk(y,y) - (2/mn)Σk(x,y)
    return loss
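A quick usage sketch (my own, with arbitrary shapes): for two batches drawn from the same distribution the loss should be close to 0, and it should grow as the distributions drift apart.

if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(64, 10)              # source batch
    y_same = torch.randn(64, 10)         # same distribution as x
    y_shift = torch.randn(64, 10) + 1.0  # mean-shifted distribution

    print(get_MMD(x, y_same))   # small value, close to 0
    print(get_MMD(x, y_shift))  # noticeably larger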
