Speed up Barlow Twins loss (#2519)

Adrià Arrufat 2022-02-25 12:42:50 +09:00 committed by GitHub
parent 50b78da53a
commit 1ccd03fec9

@@ -4016,9 +4016,6 @@ namespace dlib
// Normalize both batches independently across the batch dimension
const double eps = 1e-4;
resizable_tensor za_norm, means_a, invstds_a;
resizable_tensor zb_norm, means_b, invstds_b;
resizable_tensor rms, rvs, g, b;
g.set_size(1, sample_size);
g = 1;
b.set_size(1, sample_size);
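
The tensors set up here (g fixed to 1, b fixed to 0) parameterize the tt::batch_normalize calls at the start of the next hunk, which standardize each of the sample_size embedding dimensions to zero mean and unit variance over the batch; that standardization is what makes the matrix product computed afterwards an empirical cross-correlation matrix rather than a raw Gram matrix. In equation form (notation mine, not from the source), for batch index b and feature index i:

\hat{z}^{A}_{b,i} = \frac{z^{A}_{b,i} - \mu^{A}_{i}}{\sqrt{(\sigma^{A}_{i})^{2} + \epsilon}}, \qquad \hat{z}^{B}_{b,i} = \frac{z^{B}_{b,i} - \mu^{B}_{i}}{\sqrt{(\sigma^{B}_{i})^{2} + \epsilon}}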
@@ -4027,21 +4024,29 @@ namespace dlib
tt::batch_normalize(eps, zb_norm, means_b, invstds_b, 1, rms, rvs, zb, g, b);
// Compute the empirical cross-correlation matrix
resizable_tensor eccm;
eccm.set_size(sample_size, sample_size);
tt::gemm(0, eccm, 1, za_norm, true, zb_norm, false);
eccm /= batch_size;
// Compute the loss: MSE between eccm and the identity matrix.
// Off-diagonal terms are weighted by lambda.
const matrix<float> C = mat(eccm);
const double diagonal_loss = sum(squared(diag(C) - 1));
const double off_diag_loss = sum(squared(C - diagm(diag(C))));
double loss = diagonal_loss + lambda * off_diag_loss;
// Set sizes and setup auxiliary tensors
if (!have_same_dimensions(eccm, identity))
identity = identity_matrix<float>(sample_size);
if (!have_same_dimensions(eccm, cdiag))
cdiag.copy_size(eccm);
if (!have_same_dimensions(eccm, cdiag_1))
cdiag_1.copy_size(eccm);
if (!have_same_dimensions(eccm, off_mask))
off_mask = ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size);
if (!have_same_dimensions(eccm, off_diag))
off_diag.copy_size(eccm);
if (!have_same_dimensions(grad, grad_input))
grad_input.copy_size(grad);
if (!have_same_dimensions(g_grad, g))
g_grad.copy_size(g);
if (!have_same_dimensions(b_grad, b))
b_grad.copy_size(b);
// Loss gradient, which will be used as the input of the batch normalization gradient
resizable_tensor grad_input;
grad_input.copy_size(grad);
auto grad_input_a = split(grad_input);
auto grad_input_b = split(grad_input, offset);
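
For reference, the quantities assembled in this hunk are the empirical cross-correlation matrix and the Barlow Twins loss. In my notation (not from the source; N = batch_size, lambda as in the code):

C_{ij} = \frac{1}{N} \sum_{b=1}^{N} \hat{z}^{A}_{b,i}\,\hat{z}^{B}_{b,j}, \qquad L = \sum_{i} (C_{ii} - 1)^{2} + \lambda \sum_{i} \sum_{j \neq i} C_{ij}^{2}

In the reworked code, cdiag_1 ends up holding diag(C) - I (a matrix whose only nonzero entries are C_ii - 1) and off_diag holds C with its diagonal zeroed, so both loss terms can be read directly off these cached tensors at the end of the function instead of building a temporary matrix<float> copy of eccm.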
@@ -4051,11 +4056,15 @@ namespace dlib
// C = eccm
// D = off_mask: a mask that keeps only the elements outside the diagonal
// A diagonal matrix containing the diagonal of eccm
tt::multiply(false, cdiag, eccm, identity);
// The diagonal of eccm minus the identity matrix
tt::affine_transform(cdiag_1, cdiag, identity, 1, -1);
// diagonal term: sum((diag(A' * B) - vector(1)).^2)
// --------------------------------------------
// => d/dA = 2 * B * diag(diag(A' * B) - vector(1)) = 2 * B * diag(diag(C) - vector(1))
// => d/dB = 2 * A * diag(diag(A' * B) - vector(1)) = 2 * A * diag(diag(C) - vector(1))
resizable_tensor cdiag_1(diagm(diag(mat(eccm) - 1)));
tt::gemm(0, grad_input_a, 2, zb_norm, false, cdiag_1, false);
tt::gemm(0, grad_input_b, 2, za_norm, false, cdiag_1, false);
@@ -4063,22 +4072,21 @@ namespace dlib
// --------------------------------
// => d/dA = 2 * B * ((B' * A) .* (D .* D)') = 2 * B * (C .* (D .* D)) = 2 * B * (C .* D)
// => d/dB = 2 * A * ((A' * B) .* (D .* D)) = 2 * A * (C .* (D .* D)) = 2 * A * (C .* D)
resizable_tensor off_mask(ones_matrix<float>(sample_size, sample_size) - identity_matrix<float>(sample_size));
resizable_tensor off_diag(sample_size, sample_size);
tt::multiply(false, off_diag, eccm, off_mask);
tt::gemm(1, grad_input_a, 2 * lambda, zb_norm, false, off_diag, false);
tt::gemm(1, grad_input_b, 2 * lambda, za_norm, false, off_diag, false);
// Compute the batch norm gradients, g and b grads are not used
resizable_tensor g_grad, b_grad;
g_grad.copy_size(g);
b_grad.copy_size(b);
auto gza = split(grad);
auto gzb = split(grad, offset);
tt::batch_normalize_gradient(eps, grad_input_a, means_a, invstds_a, za, g, gza, g_grad, b_grad);
tt::batch_normalize_gradient(eps, grad_input_b, means_b, invstds_b, zb, g, gzb, g_grad, b_grad);
return loss;
// Compute the loss: MSE between eccm and the identity matrix.
// Off-diagonal terms are weighted by lambda.
const double diagonal_loss = sum(squared(mat(cdiag_1)));
const double off_diag_loss = sum(squared(mat(off_diag)));
return diagonal_loss + lambda * off_diag_loss;
}
float get_lambda() const { return lambda; }
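
Combining the two gradient hunks above: the gradient accumulated into grad_input_a and grad_input_b before it is pushed through tt::batch_normalize_gradient is (my notation; D is the off-diagonal mask off_mask, \hat{Z}^{A} and \hat{Z}^{B} are za_norm and zb_norm):

\frac{\partial L}{\partial \hat{Z}^{A}} = 2\,\hat{Z}^{B}\left(\operatorname{diag}\!\left(\operatorname{diag}(C) - \mathbf{1}\right) + \lambda\,(C \odot D)\right), \qquad \frac{\partial L}{\partial \hat{Z}^{B}} = 2\,\hat{Z}^{A}\left(\operatorname{diag}\!\left(\operatorname{diag}(C) - \mathbf{1}\right) + \lambda\,(C \odot D)\right)

which is exactly what the two pairs of tt::gemm calls compute, with the diagonal factor reused from cdiag_1 and the masked matrix reused from off_diag.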
@@ -4116,6 +4124,11 @@ namespace dlib
private:
float lambda = 0.0051;
mutable resizable_tensor za_norm, means_a, invstds_a;
mutable resizable_tensor zb_norm, means_b, invstds_b;
mutable resizable_tensor rms, rvs, g, b;
mutable resizable_tensor eccm, grad_input, g_grad, b_grad;
mutable resizable_tensor cdiag, cdiag_1, identity, off_mask, off_diag;
};
template <typename SUBNET>
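
The speedup itself comes from the mutable member tensors declared in the last hunk: every intermediate that used to be a local resizable_tensor (eccm, identity, off_mask, cdiag, cdiag_1, off_diag, grad_input, g_grad, b_grad) is now allocated once and only resized when the input shape changes, so repeated forward/backward passes no longer pay for per-call tensor allocations. Below is a minimal sketch of that allocate-once, reuse pattern, assuming only the dlib tensor API already used in the diff; the struct and member names are illustrative, not part of dlib.

#include <dlib/dnn.h>

// Illustrative: cache a workspace tensor as a mutable member so that a const
// evaluation method does not reallocate it on every call.
struct cached_workspace_example
{
    double operator()(const dlib::tensor& embeddings) const
    {
        // Resize the cached buffer only when the input dimensions change.
        if (!dlib::have_same_dimensions(scratch, embeddings))
            scratch.copy_size(embeddings);
        // On every call with the same shape, the same memory is reused.
        dlib::memcpy(scratch, embeddings);
        return dlib::sum(dlib::mat(scratch));
    }

    // mutable, mirroring the diff: the method is const but the cache may grow.
    mutable dlib::resizable_tensor scratch;
};

Compared with constructing a fresh resizable_tensor inside the function, this avoids a heap (and, with CUDA, a device) allocation per call, which is presumably where the speedup in the commit title comes from.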