# -*- coding: utf-8 -*-
#
# Copyright 2011-2012 Canonical Ltd.
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU General Public License version 3, as published
# by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranties of
# MERCHANTABILITY, SATISFACTORY QUALITY, or FITNESS FOR A PARTICULAR
# PURPOSE.  See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# In addition, as a special exception, the copyright holders give
# permission to link the code of portions of this program with the
# OpenSSL library under certain conditions as described in each
# individual source file, and distribute linked combinations
# including the two.
# You must obey the GNU General Public License in all respects
# for all of the code used other than OpenSSL.  If you modify
# file(s) with this exception, you may extend this exception to your
# version of the file(s), but you are not obligated to do so.  If you
# do not wish to do so, delete this exception statement from your
# version.  If you delete this exception statement from all source
# files in the program, then also delete it here.
"""Windows specific tests for the main module."""

from twisted.internet import defer

from ubuntu_sso.main import windows
from ubuntu_sso.tests import TestCase

# Because we are using twisted, we have Java-like names (pylint C0103).
# pylint: disable=C0103


class MockWin32APIs(object):
    """Stand-in object exposing a few fake win32 API calls.

    All handles are opaque sentinel objects, which lets each fake call
    check that it received the value handed out by this mock.
    """

    # Class-level sentinels, shared by every instance of the mock.
    TokenUser = object()
    TOKEN_ALL_ACCESS = object()
    process_handle = object()

    def __init__(self, sample_token):
        """Keep the token to hand out and create a per-instance handle."""
        self.token_handle = object()
        self.sample_token = sample_token

    def GetCurrentProcess(self):
        """Return the fake handle of the current process."""
        fake_handle = self.process_handle
        return fake_handle

    def OpenProcessToken(self, process_handle, access):
        """Return the fake token handle after checking both arguments."""
        expected_process = self.process_handle
        expected_access = self.TOKEN_ALL_ACCESS
        assert process_handle is expected_process
        assert access is expected_access
        return self.token_handle

    def GetTokenInformation(self, token_handle, info):
        """Return a ``(sample_token, 0)`` pair after checking the arguments."""
        expected_handle = self.token_handle
        expected_info = self.TokenUser
        assert token_handle == expected_handle
        assert info == expected_info
        token = self.sample_token
        return token, 0

# pylint: enable=C0103


class MiscTestCase(TestCase):
    """Exercise the module-level helper functions of the windows module."""

    def test_get_userid(self):
        """The id 1001 is extracted from the token 'abc-123-1001'."""
        fake_apis = MockWin32APIs("abc-123-1001")
        # One mock instance replaces both patched module attributes.
        for module_name in ("win32process", "win32security"):
            self.patch(windows, module_name, fake_apis)

        self.assertEqual(windows.get_user_id(), 1001)

    def _test_port_assignation(self, uid, expected_port):
        """Check the port returned while get_user_id yields *uid*."""
        def fake_get_user_id():
            """Return the uid under test."""
            return uid

        self.patch(windows, "get_user_id", fake_get_user_id)
        actual_port = windows.get_sso_pb_port()
        self.assertEqual(actual_port, expected_port)

    def test_get_sso_pb_port(self):
        """Check the port computed for uid 1001."""
        uid = 1001
        slot = uid % windows.SSO_RESERVED_PORTS
        offset = slot * windows.SSO_PORT_ALLOCATION_STEP
        self._test_port_assignation(uid, windows.SSO_BASE_PB_PORT + offset)

    def test_get_sso_pb_port_alt(self):
        """Check the port for a uid shifted by SSO_RESERVED_PORTS."""
        uid = 2011 + windows.SSO_RESERVED_PORTS
        slot = uid % windows.SSO_RESERVED_PORTS
        offset = slot * windows.SSO_PORT_ALLOCATION_STEP
        self._test_port_assignation(uid, windows.SSO_BASE_PB_PORT + offset)


class DescriptionFactoryTestcase(TestCase):
    """Check the descriptions exposed by DescriptionFactory."""

    @defer.inlineCallbacks
    def setUp(self):
        """Patch get_sso_pb_port so that it returns a fixed port."""
        yield super(DescriptionFactoryTestcase, self).setUp()
        self.port = 55555

        def fixed_port():
            """Return the port stored on the test case at call time."""
            return self.port

        self.patch(windows, 'get_sso_pb_port', fixed_port)

    def test_server_description(self):
        """The server description is the server pattern with the port."""
        pattern = windows.DescriptionFactory.server_description_pattern
        expected = pattern % self.port
        factory = windows.DescriptionFactory()
        self.assertEqual(expected, factory.server)

    def test_client_description(self):
        """The client description is the client pattern with the port."""
        pattern = windows.DescriptionFactory.client_description_pattern
        expected = pattern % self.port
        factory = windows.DescriptionFactory()
        self.assertEqual(expected, factory.client)
