def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Build the multi-bandwidth Gaussian kernel matrix over source and target.

    The rows of ``source`` and ``target`` are stacked into one
    ``(n + m, d)`` matrix, pairwise squared Euclidean distances are computed,
    and ``kernel_num`` kernels ``exp(-dist / (bandwidth_i + 1e-8))`` are
    summed element-wise.

    Params:
        source: source domain data, 2-D tensor of shape (n, d)
        target: target domain data, 2-D tensor of shape (m, d)
        kernel_mul: ratio between consecutive bandwidths. The bandwidths are
            ``base * kernel_mul ** i`` for ``i`` in ``range(kernel_num)``,
            where ``base = bandwidth / kernel_mul ** (kernel_num // 2)``.
        kernel_num: number of Gaussian kernels to sum
        fix_sigma: bandwidth to use as the centre of the bandwidth series.
            When falsy (``None`` or ``0``), the bandwidth is instead the sum
            of all pairwise squared distances divided by
            ``N ** 2 - N`` with ``N = n + m``.
    Return:
        Tensor of shape (n + m, n + m): the element-wise sum of the
        ``kernel_num`` kernel matrices.
    """
    n_samples = int(source.size(0)) + int(target.size(0))

    # Stack source rows on top of target rows (concatenation along dim 0).
    total = torch.cat([source, target], dim=0)
    n_total, n_features = int(total.size(0)), int(total.size(1))

    # Two broadcast views of the stacked data so that entry (i, j) of the
    # difference pairs row j with row i.
    total0 = total.unsqueeze(0).expand(n_total, n_total, n_features)
    total1 = total.unsqueeze(1).expand(n_total, n_total, n_features)

    # Squared Euclidean distance between every pair of rows (0 on the diagonal).
    l2_distance = ((total0 - total1) ** 2).sum(2)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Detached so the adaptive bandwidth is not part of the autograd graph.
        bandwidth = torch.sum(l2_distance.detach()) / (n_samples ** 2 - n_samples)

    # Out-of-place division: an in-place `/=` would modify the caller's
    # `fix_sigma` when it is passed as a tensor.
    bandwidth = bandwidth / kernel_mul ** (kernel_num // 2)
    bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]

    # The 1e-8 term guards against division by a zero bandwidth.
    kernel_val = [
        torch.exp(-l2_distance / (bandwidth_temp + 1e-8))
        for bandwidth_temp in bandwidth_list
    ]
    return sum(kernel_val)
def get_MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Compute the MMD loss between source and target domain data.

    The joint kernel matrix from ``guassian_kernel`` is split into its four
    quadrants and the loss is
    ``mean(K_xx) + mean(K_yy) - mean(K_xy) - mean(K_yx)``.
    Each mean is taken over the whole quadrant, diagonal entries included.

    Params:
        source: source domain data, 2-D tensor of shape (n, d)
        target: target domain data, 2-D tensor of shape (m, d)
        kernel_mul: forwarded to ``guassian_kernel`` (ratio between
            consecutive bandwidths)
        kernel_num: forwarded to ``guassian_kernel`` (number of kernels)
        fix_sigma: forwarded to ``guassian_kernel`` (fixed bandwidth, or
            ``None`` to derive it from the data)
    Return:
        loss: 0-dim tensor holding the MMD loss
    """
    # The first `batch_size` rows/columns of the kernel matrix belong to
    # source; the remainder belong to target.
    batch_size = int(source.size(0))
    kernels = guassian_kernel(
        source,
        target,
        kernel_mul=kernel_mul,
        kernel_num=kernel_num,
        fix_sigma=fix_sigma,
    )

    XX = torch.mean(kernels[:batch_size, :batch_size])  # source vs source
    YY = torch.mean(kernels[batch_size:, batch_size:])  # target vs target
    XY = torch.mean(kernels[:batch_size, batch_size:])  # source vs target
    YX = torch.mean(kernels[batch_size:, :batch_size])  # target vs source

    # Each term is already a scalar, so no further reduction is needed.
    loss = XX + YY - XY - YX
    return loss