# Copyright 2013 Canonical Ltd.  This software is licensed under the
# GNU Affero General Public License version 3 (see the file LICENSE).

"""Tests for dhcp/detect.py"""

from __future__ import (
    absolute_import,
    print_function,
    unicode_literals,
    )

# Rebinding the builtin `str` to None presumably guards against accidental
# use of the ambiguous native string type in this module: any use of `str`
# here fails.  Code should name `bytes` or unicode text explicitly.
str = None

# On Python 2, a module-level __metaclass__ makes class statements without
# explicit bases in this module new-style classes.
__metaclass__ = type
# This test module exports no public names.
__all__ = []

from errno import EADDRNOTAVAIL
import fcntl
import os
from random import randint
import socket
import string

from maastest.detect_dhcp import (
    BOOTP_CLIENT_PORT,
    BOOTP_SERVER_PORT,
    DHCPDiscoverPacket,
    DHCPOfferPacket,
    get_interface_IP,
    get_interface_MAC,
    make_transaction_ID,
    probe_dhcp,
    receive_offers,
    request_dhcp,
    udp_socket,
    )
import maastest.detect_dhcp as detect_module
import mock
from testtools import TestCase


def make_MAC():
    """Return an arbitrary, well-formed MAC address.

    Each of the six octets is rendered as exactly two lowercase hex digits,
    e.g. ``0a:1b:2c:3d:4e:5f``.
    """
    # '%02x' zero-pads small octets; plain '%x' would produce malformed
    # addresses such as 'a:3:ff:...'.
    return ':'.join('%02x' % randint(0, 255) for _ in range(6))


def make_IP():
    """Return a random dotted-quad IPv4 address.

    Every octet is drawn from 1..254, so 0 and 255 never appear.
    """
    octets = [randint(1, 254) for _ in range(4)]
    return '.'.join(map('{:d}'.format, octets))


def pick_item(items):
    """Return an arbitrary item from the sequence `items`.

    `items` must support `len()` and integer indexing, such as a list,
    tuple or string.  An empty sequence makes `randint` raise `ValueError`.
    """
    index = randint(0, len(items) - 1)
    return items[index]


def make_name(prefix=None, length=10, sep='_'):
    """Return an arbitrary identifier-style string.

    :param prefix: Optional prefix.  When given, the result starts with
        ``prefix + sep``; otherwise it starts with a random ASCII letter.
    :param length: Random ASCII letters and digits are appended until the
        name is at least this long.  A longer prefix is kept whole.
    :param sep: Separator between `prefix` and the random part; unused
        when `prefix` is None.
    """
    # Use the ASCII-only constants.  `string.letters` is locale-dependent
    # on Python 2, where it may contain non-ASCII bytes that break when
    # mixed with unicode text.  It does not exist at all on Python 3.
    alphanumerics = string.ascii_letters + string.digits
    if prefix is None:
        name = pick_item(string.ascii_letters)
    else:
        name = prefix + sep
    while len(name) < length:
        name += pick_item(alphanumerics)
    return name


def make_bytes(length=None):
    """Return random `bytes` from `os.urandom`.

    :param length: Number of bytes wanted; when None, a random size from
        1 to 50 inclusive is used.
    """
    size = randint(1, 50) if length is None else length
    return os.urandom(size)


class TestMakeTransactionID(TestCase):
    """Tests for `make_transaction_ID`."""

    def test_produces_well_formed_ID(self):
        # A DHCP transaction ID is expected to be exactly 4 raw bytes.
        new_id = make_transaction_ID()
        self.assertIsInstance(new_id, bytes)
        self.assertEqual(4, len(new_id))

    def test_randomises(self):
        # Two successive IDs should differ.
        first, second = make_transaction_ID(), make_transaction_ID()
        self.assertNotEqual(first, second)


class TestDHCPDiscoverPacket(TestCase):
    """Tests for `DHCPDiscoverPacket`."""

    def test_init_sets_transaction_ID(self):
        transaction_id = make_transaction_ID()
        self.patch(
            detect_module, 'make_transaction_ID',
            mock.MagicMock(return_value=transaction_id))

        discover = DHCPDiscoverPacket(make_MAC())

        self.assertEqual(transaction_id, discover.transaction_ID)

    def test_init_sets_packed_mac(self):
        mac = make_MAC()
        discover = DHCPDiscoverPacket(mac)
        self.assertEqual(
            discover.string_mac_to_packed(mac),
            discover.packed_mac)

    def test_init_sets_packet(self):
        discover = DHCPDiscoverPacket(make_MAC())
        self.assertIsNotNone(discover.packet)

    def test_string_mac_to_packed(self):
        # Called on the class itself; no instance is needed.
        expected = b"\x01\x22\x33\x99\xaa\xff"
        mac_string = "01:22:33:99:aa:ff"
        self.assertEqual(
            expected, DHCPDiscoverPacket.string_mac_to_packed(mac_string))

    def test__build(self):
        discover = DHCPDiscoverPacket(make_MAC())
        discover._build()

        # The expected packet, field by field (BOOTP/DHCP layout).
        expected = (
            # op=1 (request), htype=1, hlen=6, hops=0.
            b'\x01\x01\x06\x00' +
            discover.transaction_ID +
            # secs=0, then flags with only the top (broadcast) bit set.
            b'\x00\x00\x80\x00' +
            # ciaddr, yiaddr, siaddr, giaddr: 4 zero bytes each.
            b'\x00' * 16 +
            # chaddr: the 6-byte MAC, zero-padded to 16 bytes.
            discover.packed_mac + b'\x00' * 10 +
            # sname (64 bytes) and file (128 bytes), both empty.
            b'\x00' * 64 +
            b'\x00' * 128 +
            # Magic cookie.
            b'\x63\x82\x53\x63' +
            # Option 53 (message type), length 1, value 1 (DISCOVER).
            b'\x35\x01\x01' +
            # Option 61 (client identifier), length 6: the packed MAC.
            b'\x3d\x06' + discover.packed_mac +
            # Option 55 (parameter request list), length 3: 3, 1, 6.
            b'\x37\x03\x03\x01\x06' +
            # End-of-options marker.
            b'\xff')

        self.assertEqual(expected, discover.packet)


class TestDHCPOfferPacket(TestCase):
    """Tests for `DHCPOfferPacket`."""

    def test_decodes_dhcp_server(self):
        # Four address bytes placed right after 245 zero bytes should be
        # reported as the dotted-quad DHCP server identifier.
        server_address = b'\x10\x00\x00\xaa'
        offer = DHCPOfferPacket(b'\x00' * 245 + server_address)
        self.assertEqual('16.0.0.170', offer.dhcp_server_ID)


class TestGetInterfaceMAC(TestCase):
    """Tests for `get_interface_MAC`."""

    def test_loopback_has_zero_MAC(self):
        # It's a lame test, but what other network interfaces can we reliably
        # test this on?
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Close the real socket at teardown instead of leaking it.
        self.addCleanup(sock.close)
        self.assertEqual('00:00:00:00:00:00', get_interface_MAC(sock, 'lo'))


class TestGetInterfaceIP(TestCase):
    """Tests for `get_interface_IP`."""

    def test_loopback_has_localhost_address(self):
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        # Close the real socket at teardown instead of leaking it.
        self.addCleanup(sock.close)
        self.assertEqual('127.0.0.1', get_interface_IP(sock, 'lo'))

    def test_returns_None_if_no_address(self):
        # Make the ioctl fail with EADDRNOTAVAIL, simulating an interface
        # that has no address.
        failure = IOError(EADDRNOTAVAIL, "Interface has no address")
        self.patch(fcntl, 'ioctl', mock.MagicMock(side_effect=failure))
        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.addCleanup(sock.close)
        self.assertIsNone(get_interface_IP(sock, make_name('itf')))


def patch_socket(testcase):
    """Make `socket.socket` return one shared mock, via `testcase.patch`.

    :param testcase: The test case whose `patch` method installs the fake.
    :return: The mock socket produced by every `socket.socket(...)` call.
    """
    fake_socket = mock.MagicMock()
    socket_factory = mock.MagicMock(return_value=fake_socket)
    testcase.patch(socket, 'socket', socket_factory)
    return fake_socket


class TestUDPSocket(TestCase):
    """Tests for `udp_socket`."""

    def test_yields_open_socket(self):
        patch_socket(self)
        with udp_socket() as sock:
            # Snapshot the calls while still inside the context, before
            # any exit-time cleanup can run.
            creations = list(socket.socket.mock_calls)
            closings = list(sock.close.mock_calls)
        expected_creation = [mock.call(socket.AF_INET, socket.SOCK_DGRAM)]
        self.assertEqual(expected_creation, creations)
        self.assertEqual([], closings)

    def test_closes_socket_on_exit(self):
        patch_socket(self)
        with udp_socket() as sock:
            pass
        # Leaving the context must close the socket exactly once.
        self.assertEqual([mock.call()], sock.close.mock_calls)

    def test_sets_reuseaddr(self):
        patch_socket(self)
        with udp_socket() as sock:
            pass
        expected_options = [
            mock.call(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1),
            ]
        self.assertEqual(expected_options, sock.setsockopt.mock_calls)


class TestRequestDHCP(TestCase):
    """Tests for `request_dhcp`."""

    def patch_interface_MAC(self):
        """Patch `get_interface_MAC` to return a fixed value."""
        mac = make_MAC()
        self.patch(
            detect_module, 'get_interface_MAC',
            mock.MagicMock(return_value=mac))
        return mac

    def patch_interface_IP(self):
        """Patch `get_interface_IP` to return a fixed value."""
        ip = make_IP()
        self.patch(
            detect_module, 'get_interface_IP',
            mock.MagicMock(return_value=ip))
        return ip

    def patch_transaction_ID(self):
        """Patch `make_transaction_ID` to return a fixed value."""
        transaction_id = make_transaction_ID()
        self.patch(
            detect_module, 'make_transaction_ID',
            mock.MagicMock(return_value=transaction_id))
        return transaction_id

    def test_returns_None_if_no_IP_address(self):
        # Patch the socket too, like the other tests here, so this test
        # never opens a real one.
        patch_socket(self)
        self.patch_interface_MAC()
        self.patch(
            detect_module, 'get_interface_IP',
            mock.MagicMock(return_value=None))
        self.assertIsNone(request_dhcp(make_name('itf')))

    def test_sends_discover_packet(self):
        sock = patch_socket(self)
        mac = self.patch_interface_MAC()
        self.patch_interface_IP()
        # Pin the transaction ID so the expected packet is reproducible.
        self.patch_transaction_ID()

        request_dhcp(make_name('itf'))

        # Exactly one datagram goes out: the DISCOVER packet for the
        # interface's MAC, broadcast to the BOOTP server port.
        [call] = sock.sendto.mock_calls
        _, args, _ = call
        self.assertEqual(DHCPDiscoverPacket(mac).packet, args[0])
        self.assertEqual(
            ('<broadcast>', BOOTP_SERVER_PORT),
            args[1])

    def test_returns_transaction_id(self):
        patch_socket(self)
        self.patch_interface_MAC()
        self.patch_interface_IP()
        transaction_id = self.patch_transaction_ID()
        interface = make_name('itf')

        self.assertEqual(transaction_id, request_dhcp(interface))


class FakePacketReceiver:
    """Stand-in callable for a socket's `recv` method.

    Each call records the requested size and returns the next queued
    packet, in order.  Once the queue is empty, every further call raises
    `socket.timeout`.
    """

    def __init__(self, packets=None):
        # Sizes passed to each call, in call order.
        self.calls = []
        # Private copy, so the caller's sequence is never consumed.
        self.packets = [] if packets is None else list(packets)

    def __call__(self, recv_size):
        self.calls.append(recv_size)
        if not self.packets:
            raise socket.timeout()
        return self.packets.pop(0)


def patch_recv(testcase, sock, num_packets=0):
    """Replace `sock.recv` with a `FakePacketReceiver`.

    The fake returns `num_packets` random packets and then raises a timeout
    on every later call.

    :return: The installed `FakePacketReceiver`, for inspecting its calls.
    """
    fake_packets = []
    for _ in range(num_packets):
        fake_packets.append(make_bytes())
    fake_recv = FakePacketReceiver(fake_packets)
    testcase.patch(sock, 'recv', fake_recv)
    return fake_recv


def patch_offer_packet(testcase):
    """Replace `DHCPOfferPacket` with a factory that returns one mock offer.

    The mock offer has a random 4-byte `transaction_ID` and a random IPv4
    `dhcp_server_ID`.

    :return: The mock offer returned for every parsed packet.
    """
    offer = mock.MagicMock()
    offer.transaction_ID = make_bytes(4)
    offer.dhcp_server_ID = make_IP()
    offer_factory = mock.MagicMock(return_value=offer)
    testcase.patch(detect_module, 'DHCPOfferPacket', offer_factory)
    return offer


class TestReceiveOffers(TestCase):
    """Tests for `receive_offers`."""

    def test_receives_from_socket(self):
        sock = patch_socket(self)
        fake_recv = patch_recv(self, sock)
        offer = patch_offer_packet(self)

        receive_offers(offer.transaction_ID)

        # One UDP socket, bound to the BOOTP client port on all addresses,
        # read once with a 1024-byte buffer before the fake timeout.
        self.assertEqual(
            [mock.call(socket.AF_INET, socket.SOCK_DGRAM)],
            socket.socket.mock_calls)
        self.assertEqual(
            [mock.call(('', BOOTP_CLIENT_PORT))],
            sock.bind.mock_calls)
        self.assertEqual([1024], fake_recv.calls)

    def test_returns_empty_if_nothing_received(self):
        sock = patch_socket(self)
        patch_recv(self, sock)
        offer = patch_offer_packet(self)

        self.assertEqual(set(), receive_offers(offer.transaction_ID))

    def test_processes_offer(self):
        sock = patch_socket(self)
        patch_recv(self, sock, 1)
        offer = patch_offer_packet(self)

        found = receive_offers(offer.transaction_ID)

        self.assertEqual({offer.dhcp_server_ID}, found)

    def test_ignores_other_transactions(self):
        sock = patch_socket(self)
        patch_recv(self, sock, 1)
        patch_offer_packet(self)
        # A fresh random ID that the mock offer does not carry.
        unrelated_id = make_bytes(4)

        self.assertEqual(set(), receive_offers(unrelated_id))

    def test_propagates_errors_other_than_timeout(self):
        class InducedError(Exception):
            """Deliberately induced error for testing."""

        sock = patch_socket(self)
        sock.recv = mock.MagicMock(side_effect=InducedError)

        self.assertRaises(InducedError, receive_offers, make_bytes(4))


class TestProbeDHCP(TestCase):
    """Tests for `probe_dhcp`."""

    def test_detects_dhcp_servers(self):
        sock = patch_socket(self)
        patch_recv(self, sock, 1)
        packet = patch_offer_packet(self)
        transaction_id = packet.transaction_ID
        server = packet.dhcp_server_ID
        self.patch(
            detect_module, 'request_dhcp',
            mock.MagicMock(return_value=transaction_id))

        servers = probe_dhcp(make_name('itf'))

        self.assertEqual({server}, servers)

    def test_tolerates_interface_without_address(self):
        # Patch the socket as well, as the other tests here do, so nothing
        # touches the real network stack.  `recv` is deliberately left
        # unpatched: a fake timeout would mask a probe that wrongly goes on
        # to listen for offers.
        patch_socket(self)
        self.patch(
            detect_module, 'get_interface_MAC',
            mock.MagicMock(return_value=make_MAC()))
        self.patch(
            detect_module, 'get_interface_IP',
            mock.MagicMock(return_value=None))
        servers = probe_dhcp(make_name('itf'))
        self.assertEqual(set(), servers)
