# Copyright 2019 Baidu Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""C&W attacks (L2 and L_inf) for evaluating model robustness."""

import warnings
import logging
import numpy as np
from tqdm import tqdm
from abc import ABC
from abc import abstractmethod
from .base import Metric
from .base import call_decorator
from perceptron.utils.image import onehot_like
from perceptron.utils.func import to_tanh_space
from perceptron.utils.func import to_model_space
from perceptron.utils.func import AdamOptimizer

class CarliniWagnerMetric(Metric, ABC):
    """The base class of the Carlini & Wagner attack.

    This attack is described in [1]_. This implementation
    is based on the reference implementation by Carlini [2]_.
    For bounds ≠ (0, 1), it differs from [2]_ because we
    normalize the distance term of the loss with the bounds.

    Subclasses select the distance (e.g. L2 or L_inf) by implementing
    :meth:`lp_distance_and_grad`.

    References
    ----------
    .. [1] Nicholas Carlini, David Wagner: "Towards Evaluating the
           Robustness of Neural Networks", https://arxiv.org/abs/1608.04644
    .. [2] https://github.com/carlini/nn_robust_attacks

    """

    @call_decorator
    def __call__(self, adv, annotation=None, unpack=True,
                 binary_search_steps=5, max_iterations=1000,
                 confidence=0, learning_rate=5e-3,
                 initial_const=1e-2, abort_early=True):
        """Run the Carlini & Wagner attack using the subclass's distance.

        Parameters
        ----------
        adv : :class:`Adversarial`
            An :class:`Adversarial` instance.
        annotation : int
            The reference label of the original input.
        unpack : bool
            If true, returns the adversarial input, otherwise returns
            the Adversarial object.
        binary_search_steps : int
            The number of steps for the binary search used to find the
            optimal tradeoff-constant between distance and confidence.
        max_iterations : int
            The maximum number of iterations. Larger values are more
            accurate; setting it too small will require a large learning
            rate and will produce poor results.
        confidence : int or float
            Confidence of adversarial examples: a higher value produces
            adversarials that are further away, but more strongly classified
            as adversarial.
        learning_rate : float
            The learning rate for the attack algorithm. Smaller values
            produce better results but take longer to converge.
        initial_const : float
            The initial tradeoff-constant to use to tune the relative
            importance of distance and confidence. If `binary_search_steps`
            is large, the initial constant is not important.
        abort_early : bool
            If True, Adam will be aborted if the loss has not decreased
            (by at least 0.01%) since the previous check; checks happen
            every tenth of `max_iterations`.

        Raises
        ------
        ValueError
            If the model task is neither 'cls' nor 'det'.

        """
        a = adv

        # only the Adversarial object is used below
        del adv
        del annotation
        del unpack

        if not a.has_gradient():
            logging.fatal('Applied gradient-based attack to model that '
                          'does not provide gradients.')
            # logging.fatal only logs; without returning, the attack would
            # go on to request gradients the model cannot provide
            return

        min_, max_ = a.bounds()

        task = a.model_task()
        if task == 'cls':
            loss_and_gradient = self.cls_loss_and_gradient
        elif task == 'det':
            loss_and_gradient = self.det_loss_and_gradient
        else:
            raise ValueError('Model task not supported. Check that the'
                             ' task is either cls or det')

        # variables representing inputs in attack space will be
        # prefixed with att_
        att_original = to_tanh_space(a.original_image, min_, max_)

        # will be close but not identical to a.original_image
        reconstructed_original, _ = to_model_space(att_original, min_, max_)

        # the binary search finds the smallest const for which we
        # find an adversarial
        const = initial_const
        lower_bound = 0
        upper_bound = np.inf

        for binary_search_step in tqdm(range(binary_search_steps)):
            if binary_search_step == binary_search_steps - 1 and \
                    binary_search_steps >= 10 and \
                    np.isfinite(upper_bound):
                # in the last step of a long search, retry with the smallest
                # const known to succeed; skipped when no adversarial has
                # been found, since an infinite const turns the loss and
                # gradient into inf/nan
                const = upper_bound

            logging.info(
                'starting optimization with const = {}'.format(const))
            att_perturbation = np.zeros_like(att_original)

            # create a new optimizer to minimize the perturbation
            optimizer = AdamOptimizer(att_perturbation.shape)

            found_adv = False  # found adv with the current const
            loss_at_previous_check = np.inf

            for iteration in range(max_iterations):
                x, dxdp = to_model_space(
                    att_original + att_perturbation, min_, max_)

                loss, gradient, is_adv = loss_and_gradient(
                    const, a, x, dxdp, reconstructed_original,
                    confidence, min_, max_)

                logging.info(
                    'iter: {}; loss: {}; best overall distance: {}'.format(
                        iteration, loss, a.distance))

                att_perturbation += optimizer(gradient, learning_rate)

                if is_adv:
                    # this binary search step can be considered a success
                    # but optimization continues to minimize perturbation size
                    found_adv = True

                if abort_early and \
                        iteration % (np.ceil(max_iterations / 10)) == 0:
                    # after each tenth of the iterations, check progress
                    if not (loss <= .9999 * loss_at_previous_check):
                        break  # stop Adam if there has not been progress
                    loss_at_previous_check = loss

            if found_adv:
                logging.info(
                    'found adversarial with const = {}'.format(const))
                upper_bound = const
            else:
                logging.info('failed to find adversarial '
                             'with const = {}'.format(const))
                lower_bound = const

            if upper_bound == np.inf:
                # exponential search
                const *= 10
            else:
                # binary search
                const = (lower_bound + upper_bound) / 2

    @staticmethod
    @abstractmethod
    def lp_distance_and_grad(reference, other, span):
        """Return the span-normalized L_p distance and its gradient.

        To be extended with different L_p norms by subclasses.

        Parameters
        ----------
        reference : numpy.ndarray
            The (reconstructed) original input.
        other : numpy.ndarray
            The current candidate input.
        span : float
            ``max_ - min_`` of the model bounds, used for normalization.

        Returns
        -------
        tuple
            ``(distance, gradient of distance w.r.t. other)``.

        """
        raise NotImplementedError

    @classmethod
    def det_loss_and_gradient(cls, const, a, x, dxdp,
                              reconstructed_original, confidence, min_, max_):
        """Return the total loss, its gradient w.r.t. the attack-space
        perturbation, and whether `x` is adversarial, for detection models.

        The adversarial loss and its gradient w.r.t. x are taken from the
        Adversarial object; they are combined with the subclass's distance
        term as ``distance + const * max(0, adv_loss + confidence)``.

        Raises
        ------
        NotImplementedError
            For untargeted attacks, which are not supported for detection.

        """
        if a.target_class() is None:
            # only targeted detection attacks are implemented
            raise NotImplementedError

        # NOTE(review): this call's continuation line was lost in the
        # garbled source and is reconstructed from the 4-value unpacking;
        # confirm against the detection Adversarial API.
        _, is_adv_loss, is_adv_loss_grad, is_adv = \
            a.predictions_and_gradient(x)

        # is_adv is True as soon as the is_adv_loss goes below 0
        # but sometimes we want additional confidence
        is_adv_loss = max(0, is_adv_loss + confidence)
        if is_adv_loss == 0:
            # the hinge max(0, .) is inactive, so it contributes no
            # gradient (same treatment as in cls_loss_and_gradient)
            is_adv_loss_grad = 0

        s = max_ - min_
        lp_distance, lp_distance_grad = \
            cls.lp_distance_and_grad(reconstructed_original, x, s)

        total_loss = lp_distance + const * is_adv_loss
        total_loss_grad = lp_distance_grad + const * is_adv_loss_grad

        # backprop the gradient of the loss w.r.t. x further
        # to get the gradient of the loss w.r.t. att_perturbation
        assert total_loss_grad.shape == x.shape
        assert dxdp.shape == x.shape
        # we can do a simple elementwise multiplication, because
        # grad_x_wrt_p is a matrix of elementwise derivatives
        # (i.e. each x[i] w.r.t. p[i] only, for all i) and
        # grad_loss_wrt_x is a real gradient reshaped as a matrix
        gradient = total_loss_grad * dxdp

        return total_loss, gradient, is_adv

    @classmethod
    def cls_loss_and_gradient(cls, const, a, x, dxdp,
                              reconstructed_original, confidence, min_, max_):
        """Return the total loss, its gradient w.r.t. the attack-space
        perturbation, and whether `x` is adversarial, for classifiers,
        assuming that logits = model(x).

        The adversarial term is the logit margin
        ``max(0, logits[c_minimize] - logits[c_maximize] + confidence)``.

        """
        logits, is_adv = a.predictions(x)

        targeted = a.target_class() is not None
        if targeted:
            # push the target above every other class
            c_minimize = cls.best_other_class(logits, a.target_class())
            c_maximize = a.target_class()
        else:
            # push some other class above the original prediction
            c_minimize = a.original_pred
            c_maximize = cls.best_other_class(logits, a.original_pred)

        is_adv_loss = logits[c_minimize] - logits[c_maximize]

        # is_adv is True as soon as the is_adv_loss goes below 0
        # but sometimes we want additional confidence
        is_adv_loss += confidence
        is_adv_loss = max(0, is_adv_loss)

        s = max_ - min_
        lp_distance, lp_distance_grad = \
            cls.lp_distance_and_grad(reconstructed_original, x, s)
        total_loss = lp_distance + const * is_adv_loss

        # calculate the gradient of total_loss w.r.t. x
        logits_diff_grad = np.zeros_like(logits)
        logits_diff_grad[c_minimize] = 1
        logits_diff_grad[c_maximize] = -1
        is_adv_loss_grad = a.backward(logits_diff_grad, x)
        assert is_adv_loss >= 0
        if is_adv_loss == 0:
            # the hinge max(0, .) is inactive, so it contributes no gradient
            is_adv_loss_grad = 0

        total_loss_grad = lp_distance_grad + const * is_adv_loss_grad
        # backprop the gradient of the loss w.r.t. x further
        # to get the gradient of the loss w.r.t. att_perturbation
        assert total_loss_grad.shape == x.shape
        assert dxdp.shape == x.shape
        # we can do a simple elementwise multiplication, because
        # grad_x_wrt_p is a matrix of elementwise derivatives
        # (i.e. each x[i] w.r.t. p[i] only, for all i) and
        # grad_loss_wrt_x is a real gradient reshaped as a matrix
        gradient = total_loss_grad * dxdp

        return total_loss, gradient, is_adv

    @staticmethod
    def best_other_class(logits, exclude):
        """Returns the index of the largest logit, ignoring the class that
        is passed as `exclude`."""
        # subtracting +inf at `exclude` guarantees it is never the argmax
        other_logits = logits - onehot_like(logits, exclude, value=np.inf)
        return np.argmax(other_logits)

class CarliniWagnerL2Metric(CarliniWagnerMetric):
    """The L2 version of C&W attack."""

    @staticmethod
    def lp_distance_and_grad(reference, other, span):
        """Calculate the normalized squared L2 distance and its gradient.

        Parameters
        ----------
        reference : numpy.ndarray
            The (reconstructed) original input.
        other : numpy.ndarray
            The current candidate input, same shape as `reference`.
        span : float
            Width of the model bounds (``max_ - min_``).

        Returns
        -------
        squared_l2_distance : float
            ``sum((other - reference) ** 2) / span ** 2``.
        squared_l2_distance_grad : numpy.ndarray
            Gradient w.r.t. `other`: ``2 * (other - reference) / span ** 2``.

        """
        squared_l2_distance = np.sum(
            (other - reference) ** 2) / span ** 2
        squared_l2_distance_grad = (2 / span ** 2) * (other - reference)
        return squared_l2_distance, squared_l2_distance_grad


class CarliniWagnerLinfMetric(CarliniWagnerMetric):
    """The L_inf version of C&W attack."""

    @staticmethod
    def lp_distance_and_grad(reference, other, span):
        """Calculate the normalized L_inf distance and its (sub)gradient.

        Parameters
        ----------
        reference : numpy.ndarray
            The (reconstructed) original input.
        other : numpy.ndarray
            The current candidate input, same shape as `reference`.
        span : float
            Width of the model bounds (``max_ - min_``).

        Returns
        -------
        l_inf_distance : float
            ``max(|other - reference|) / span``.
        l_inf_distance_grad : numpy.ndarray
            float32 (sub)gradient of `l_inf_distance` w.r.t. `other`:
            ``sign(other - reference) / span`` at the coordinates attaining
            the maximum, zero elsewhere (all zeros if `other == reference`).

        """
        delta = other - reference
        diff = np.abs(delta)
        max_diff = np.max(diff)
        l_inf_distance = max_diff / span
        if max_diff == 0:
            # zero is a valid subgradient at the minimum
            l_inf_distance_grad = np.zeros_like(diff, dtype=np.float32)
        else:
            # d|delta_i|/d other_i = sign(delta_i); only the arg-max
            # coordinates affect the max, and the distance is divided by
            # span, so the gradient is scaled the same way
            on_max = (diff == max_diff)
            l_inf_distance_grad = (
                np.sign(delta) * on_max / span).astype(np.float32)
        return l_inf_distance, l_inf_distance_grad