Shortcuts

Source code for mmcv.ops.focal_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from ..utils import ext_loader

# Load the compiled ``_ext`` extension, requesting the four focal-loss kernels
# (sigmoid/softmax x forward/backward). The autograd Functions in this file
# call them as ``ext_module.<kernel_name>(...)``.
ext_module = ext_loader.load_ext('_ext', [
    'sigmoid_focal_loss_forward', 'sigmoid_focal_loss_backward',
    'softmax_focal_loss_forward', 'softmax_focal_loss_backward'
])


class SigmoidFocalLossFunction(Function):
    """Autograd wrapper around the compiled sigmoid focal loss kernels.

    The element-wise loss and its gradient are produced by
    ``ext_module.sigmoid_focal_loss_forward`` and
    ``ext_module.sigmoid_focal_loss_backward``. This class validates the
    inputs, applies the requested reduction and keeps what ``backward`` needs.
    """

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        """Emit the custom ``mmcv::MMCVSigmoidFocalLoss`` op for ONNX export."""
        op_attributes = dict(
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)
        return g.op('mmcv::MMCVSigmoidFocalLoss', input, target,
                    **op_attributes)

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        """Compute the sigmoid focal loss.

        Args:
            ctx: Autograd context; stores gamma/alpha/reduction for backward.
            input (torch.Tensor): 2-D score tensor.
            target (torch.Tensor): 1-D tensor with one entry per row of
                ``input``.
            gamma (float): Focusing parameter forwarded to the kernel.
            alpha (float): Balancing parameter forwarded to the kernel.
            weight (torch.Tensor, optional): 1-D tensor with ``input.size(1)``
                entries. When None, an empty tensor is handed to the kernel.
            reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.

        Returns:
            torch.Tensor: Same shape as ``input`` for ``'none'``; otherwise a
            scalar. ``'mean'`` divides the total by ``input.size(0)``.
        """
        # ``torch.Tensor`` already matches the legacy Long types listed next
        # to it; the tuple mirrors the historical check.
        accepted_target_types = (torch.Tensor, torch.LongTensor,
                                 torch.cuda.LongTensor)
        assert isinstance(target, accepted_target_types)
        assert input.dim() == 2
        assert target.dim() == 1
        batch_size = input.size(0)
        assert batch_size == target.size(0)

        if weight is not None:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        else:
            weight = input.new_empty(0)

        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # The kernel fills this buffer in place with one value per element.
        loss = input.new_zeros(input.size())
        ext_module.sigmoid_focal_loss_forward(
            input, target, weight, loss, gamma=ctx.gamma, alpha=ctx.alpha)

        if ctx.reduction == ctx.reduction_dict['sum']:
            loss = loss.sum()
        elif ctx.reduction == ctx.reduction_dict['mean']:
            loss = loss.sum() / batch_size

        ctx.save_for_backward(input, target, weight)
        return loss

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; other arguments get None."""
        logits, labels, class_weight = ctx.saved_tensors

        # The kernel writes d(loss)/d(input) element-wise into this buffer.
        grad_logits = logits.new_zeros(logits.size())
        ext_module.sigmoid_focal_loss_backward(
            logits,
            labels,
            class_weight,
            grad_logits,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        grad_logits *= grad_output
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_logits /= logits.size(0)

        # One slot per forward argument after ctx: only ``input`` is
        # differentiable.
        return (grad_logits, ) + (None, ) * 5


# Functional entry point: sigmoid_focal_loss(input, target, gamma, alpha,
# weight, reduction).
sigmoid_focal_loss = SigmoidFocalLossFunction.apply


class SigmoidFocalLoss(nn.Module):
    """``nn.Module`` wrapper around ``sigmoid_focal_loss``.

    Args:
        gamma (float): Focusing parameter forwarded to the loss.
        alpha (float): Balancing parameter forwarded to the loss.
        weight (torch.Tensor, optional): 1-D weight tensor (presumably one
            entry per class). It is registered as a buffer, so a non-None
            weight moves with the module's ``.to()`` and appears in its
            state dict. Defaults to None.
        reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.
            Defaults to ``'mean'``.
    """

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SigmoidFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        """Apply the sigmoid focal loss with this module's settings."""
        return sigmoid_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s
class SoftmaxFocalLossFunction(Function):
    """Autograd wrapper around the compiled softmax focal loss kernels.

    ``forward`` takes a softmax of ``input`` over dim 1 and passes the
    probabilities to ``ext_module.softmax_focal_loss_forward``, which fills
    one value per row. It then applies the requested reduction.
    ``backward`` calls ``ext_module.softmax_focal_loss_backward`` and scales
    the result by the incoming ``grad_output``.
    """

    @staticmethod
    def symbolic(g, input, target, gamma, alpha, weight, reduction):
        """Emit the custom ``mmcv::MMCVSoftmaxFocalLoss`` op for ONNX export."""
        return g.op(
            'mmcv::MMCVSoftmaxFocalLoss',
            input,
            target,
            gamma_f=gamma,
            alpha_f=alpha,
            weight_f=weight,
            reduction_s=reduction)

    @staticmethod
    def _scale_grad_input(grad_input, grad_output):
        """Multiply ``grad_input`` (N, C) in place by ``grad_output``.

        ``grad_output`` is a scalar for the ``'mean'``/``'sum'`` reductions
        and has shape (N,) for ``'none'``. A 1-D ``grad_output`` must scale
        each row. Plain broadcasting would instead align it with the class
        dimension C: it fails when N != C and is silently wrong when N == C.

        Returns:
            torch.Tensor: ``grad_input`` (modified in place).
        """
        if grad_output.dim() == 1:
            grad_output = grad_output.unsqueeze(1)
        grad_input *= grad_output
        return grad_input

    @staticmethod
    def forward(ctx,
                input,
                target,
                gamma=2.0,
                alpha=0.25,
                weight=None,
                reduction='mean'):
        """Compute the softmax focal loss.

        Args:
            ctx: Autograd context; stores gamma/alpha/reduction for backward.
            input (torch.Tensor): 2-D scores; the softmax is taken over dim 1.
            target (torch.LongTensor): 1-D tensor with one entry per row of
                ``input`` (presumably class indices; TODO confirm against
                the kernel).
            gamma (float): Focusing parameter forwarded to the kernel.
            alpha (float): Balancing parameter forwarded to the kernel.
            weight (torch.Tensor, optional): 1-D tensor with ``input.size(1)``
                entries. When None, an empty tensor is handed to the kernel.
            reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.

        Returns:
            torch.Tensor: Shape ``(input.size(0),)`` for ``'none'``; otherwise
            a scalar. ``'mean'`` divides the total by ``input.size(0)``.
        """
        assert isinstance(target, (torch.LongTensor, torch.cuda.LongTensor))
        assert input.dim() == 2
        assert target.dim() == 1
        assert input.size(0) == target.size(0)
        if weight is None:
            weight = input.new_empty(0)
        else:
            assert weight.dim() == 1
            assert input.size(1) == weight.size(0)
        ctx.reduction_dict = {'none': 0, 'mean': 1, 'sum': 2}
        assert reduction in ctx.reduction_dict.keys()

        ctx.gamma = float(gamma)
        ctx.alpha = float(alpha)
        ctx.reduction = ctx.reduction_dict[reduction]

        # Softmax over dim 1. Subtracting the row max before exp() keeps the
        # exponentials from overflowing.
        channel_stats, _ = torch.max(input, dim=1)
        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
        input_softmax.exp_()
        channel_stats = input_softmax.sum(dim=1)
        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)

        # The kernel fills this buffer with one value per row of ``input``.
        output = input.new_zeros(input.size(0))
        ext_module.softmax_focal_loss_forward(
            input_softmax,
            target,
            weight,
            output,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        if ctx.reduction == ctx.reduction_dict['mean']:
            output = output.sum() / input.size(0)
        elif ctx.reduction == ctx.reduction_dict['sum']:
            output = output.sum()
        # backward works from the probabilities, not the raw scores.
        ctx.save_for_backward(input_softmax, target, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``input``; other arguments get None.

        Marked ``once_differentiable`` (matching the sigmoid variant) because
        the extension kernel is not tracked by autograd. A double backward
        therefore raises instead of silently dropping gradients.
        """
        input_softmax, target, weight = ctx.saved_tensors
        # Per-row buffer required by the backward kernel (presumably scratch
        # space; TODO confirm against the kernel).
        buff = input_softmax.new_zeros(input_softmax.size(0))
        grad_input = input_softmax.new_zeros(input_softmax.size())

        ext_module.softmax_focal_loss_backward(
            input_softmax,
            target,
            weight,
            buff,
            grad_input,
            gamma=ctx.gamma,
            alpha=ctx.alpha)

        # Row-wise scaling: for reduction='none', grad_output is (N,).
        grad_input = SoftmaxFocalLossFunction._scale_grad_input(
            grad_input, grad_output)
        if ctx.reduction == ctx.reduction_dict['mean']:
            grad_input /= input_softmax.size(0)
        return grad_input, None, None, None, None, None


# Functional entry point: softmax_focal_loss(input, target, gamma, alpha,
# weight, reduction).
softmax_focal_loss = SoftmaxFocalLossFunction.apply
class SoftmaxFocalLoss(nn.Module):
    """``nn.Module`` wrapper around ``softmax_focal_loss``.

    Args:
        gamma (float): Focusing parameter forwarded to the loss.
        alpha (float): Balancing parameter forwarded to the loss.
        weight (torch.Tensor, optional): 1-D weight tensor (presumably one
            entry per class). It is registered as a buffer, so a non-None
            weight moves with the module's ``.to()`` and appears in its
            state dict. Defaults to None.
        reduction (str): ``'none'``, ``'mean'`` or ``'sum'``.
            Defaults to ``'mean'``.
    """

    def __init__(self, gamma, alpha, weight=None, reduction='mean'):
        super(SoftmaxFocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.register_buffer('weight', weight)
        self.reduction = reduction

    def forward(self, input, target):
        """Apply the softmax focal loss with this module's settings."""
        return softmax_focal_loss(input, target, self.gamma, self.alpha,
                                  self.weight, self.reduction)

    def __repr__(self):
        s = self.__class__.__name__
        s += f'(gamma={self.gamma}, '
        s += f'alpha={self.alpha}, '
        s += f'reduction={self.reduction})'
        return s
Read the Docs v: v1.5.1
Versions
latest
stable
v1.5.1
v1.5.0
v1.4.8
v1.4.7
v1.4.6
v1.4.5
v1.4.4
v1.4.3
v1.4.2
v1.4.1
v1.4.0
v1.3.18
v1.3.17
v1.3.16
v1.3.15
v1.3.14
v1.3.13
Downloads
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.