#!/usr/bin/env python



"""Unit tests for the with statement specified in PEP 343."""





__author__ = "Mike Bland"

__email__ = "mbland at acm dot org"



import functools
import sys
import unittest
from collections import deque
from contextlib import GeneratorContextManager, contextmanager
from test.test_support import run_unittest





class MockContextManager(GeneratorContextManager):

    def __init__(self, gen):

        GeneratorContextManager.__init__(self, gen)

        self.enter_called = False

        self.exit_called = False

        self.exit_args = None



    def __enter__(self):

        self.enter_called = True

        return GeneratorContextManager.__enter__(self)



    def __exit__(self, type, value, traceback):

        self.exit_called = True

        self.exit_args = (type, value, traceback)

        return GeneratorContextManager.__exit__(self, type,

                                                value, traceback)





def mock_contextmanager(func):
    """Decorate generator function *func* so calls return a MockContextManager.

    This is the recording counterpart of contextlib.contextmanager: each call
    to the decorated function builds the generator and wraps it in a
    MockContextManager.  functools.wraps keeps the original function's name
    and docstring on the returned callable.
    """
    @functools.wraps(func)
    def helper(*args, **kwds):
        return MockContextManager(func(*args, **kwds))
    return helper





class MockResource(object):
    """Plain flag holder handed out by mock_contextmanager_generator.

    Both flags start out False; the generator sets ``yielded`` before it
    yields the resource and ``stopped`` in its ``finally`` clause.
    """

    def __init__(self):
        self.yielded, self.stopped = False, False





@mock_contextmanager
def mock_contextmanager_generator():
    """Yield a fresh MockResource and mark it stopped during cleanup.

    Because of the @mock_contextmanager decorator, calling this function
    returns a MockContextManager wrapping the generator.
    """
    resource = MockResource()
    try:
        resource.yielded = True
        yield resource
    finally:
        # Runs whether the generator is resumed normally or terminated by
        # an exception.
        resource.stopped = True





class Nested(object):



    def __init__(self, *managers):

        self.managers = managers

        self.entered = None



    def __enter__(self):

        if self.entered is not None:

            raise RuntimeError("Context is not reentrant")

        self.entered = deque()

        vars = []

        try:

            for mgr in self.managers:

                vars.append(mgr.__enter__())

                self.entered.appendleft(mgr)

        except:

            if not self.__exit__(*sys.exc_info()):

                raise

        return vars



    def __exit__(self, *exc_info):

        # Behave like nested with statements

        # first in, last out

        # New exceptions override old ones

        ex = exc_info

        for mgr in self.entered:

            try:

                if mgr.__exit__(*ex):

                    ex = (None, None, None)

            except:

                ex = sys.exc_info()

        self.entered = None

        if ex is not exc_info:

            raise ex[0], ex[1], ex[2]





class MockNested(Nested):
    """Nested that records whether and how it was entered and exited.

    Attributes:
        enter_called: becomes True once __enter__ has run.
        exit_called: becomes True once __exit__ has run.
        exit_args: the exception triple handed to __exit__, or None until
            __exit__ is called.
    """

    def __init__(self, *managers):
        Nested.__init__(self, *managers)
        self.enter_called, self.exit_called = False, False
        self.exit_args = None

    def __enter__(self):
        self.enter_called = True
        return Nested.__enter__(self)

    def __exit__(self, *exc_info):
        self.exit_called, self.exit_args = True, exc_info
        return Nested.__exit__(self, *exc_info)





class FailureTestCase(unittest.TestCase):
    """With statements that must fail, at run time or at compile time."""

    def testNameError(self):
        # An undefined context expression raises NameError when executed.
        def fooNotDeclared():
            with foo: pass
        self.assertRaises(NameError, fooNotDeclared)

    def testEnterAttributeError(self):
        # A context manager lacking __enter__ is expected to raise
        # AttributeError.
        class LacksEnter(object):
            def __exit__(self, type, value, traceback):
                pass

        def fooLacksEnter():
            foo = LacksEnter()
            with foo: pass
        self.assertRaises(AttributeError, fooLacksEnter)

    def testExitAttributeError(self):
        # Likewise for a context manager lacking __exit__.
        class LacksExit(object):
            def __enter__(self):
                pass

        def fooLacksExit():
            foo = LacksExit()
            with foo: pass
        self.assertRaises(AttributeError, fooLacksExit)

    def assertRaisesSyntaxError(self, codestr):
        """Assert that compiling *codestr* in 'single' mode raises SyntaxError."""
        def shouldRaiseSyntaxError(s):
            compile(s, '', 'single')
        self.assertRaises(SyntaxError, shouldRaiseSyntaxError, codestr)

    def testAssignmentToNoneError(self):
        # None is rejected as an "as" target, bare or parenthesized.
        self.assertRaisesSyntaxError('with mock as None:\n  pass')
        self.assertRaisesSyntaxError(
            'with mock as (None):\n'
            '  pass')

    def testAssignmentToEmptyTupleError(self):
        # An empty tuple is expected to be rejected as an "as" target.
        self.assertRaisesSyntaxError(
            'with mock as ():\n'
            '  pass')

    def testAssignmentToTupleOnlyContainingNoneError(self):
        # A one-element tuple holding None, with or without parentheses.
        self.assertRaisesSyntaxError('with mock as None,:\n  pass')
        self.assertRaisesSyntaxError(
            'with mock as (None,):\n'
            '  pass')

    def testAssignmentToTupleContainingNoneError(self):
        # None anywhere inside a tuple target is rejected too.
        self.assertRaisesSyntaxError(
            'with mock as (foo, None, bar):\n'
            '  pass')

    def testEnterThrows(self):
        # If __enter__ raises, the exception propagates and the "as" target
        # (self.foo) is never assigned.
        class EnterThrows(object):
            def __enter__(self):
                raise RuntimeError("Enter threw")
            def __exit__(self, *args):
                pass

        def shouldThrow():
            ct = EnterThrows()
            self.foo = None
            with ct as self.foo:
                pass
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertEqual(self.foo, None)

    def testExitThrows(self):
        # An exception raised by __exit__ after a normal body propagates.
        class ExitThrows(object):
            def __enter__(self):
                return
            def __exit__(self, *args):
                raise RuntimeError(42)
        def shouldThrow():
            with ExitThrows():
                pass
        self.assertRaises(RuntimeError, shouldThrow)



class ContextmanagerAssertionMixin(object):
    """Assertion helpers for mock context managers and the resources they yield.

    Intended to be mixed into a unittest.TestCase: the helpers rely on the
    host class providing assertTrue, assertFalse and assertEqual.
    """

    # A single shared instance, so tests can compare what __exit__ received
    # against the exact exception that was raised.
    TEST_EXCEPTION = RuntimeError("test exception")

    def _assertEnteredAndExited(self, mock_manager):
        """Check that *mock_manager* saw both __enter__ and __exit__."""
        self.assertTrue(mock_manager.enter_called)
        self.assertTrue(mock_manager.exit_called)

    def assertInWithManagerInvariants(self, mock_manager):
        """Check *mock_manager* has been entered but not exited yet."""
        self.assertTrue(mock_manager.enter_called)
        self.assertFalse(mock_manager.exit_called)
        self.assertEqual(mock_manager.exit_args, None)

    def assertAfterWithManagerInvariants(self, mock_manager, exit_args):
        """Check *mock_manager* was exited with exactly *exit_args*."""
        self._assertEnteredAndExited(mock_manager)
        self.assertEqual(mock_manager.exit_args, exit_args)

    def assertAfterWithManagerInvariantsNoError(self, mock_manager):
        """Check *mock_manager* was exited without an exception."""
        no_exception = (None, None, None)
        self.assertAfterWithManagerInvariants(mock_manager, no_exception)

    def assertInWithGeneratorInvariants(self, mock_generator):
        """Check the yielded resource is live: yielded, not yet stopped."""
        self.assertTrue(mock_generator.yielded)
        self.assertFalse(mock_generator.stopped)

    def assertAfterWithGeneratorInvariantsNoError(self, mock_generator):
        """Check the yielded resource has been yielded and then stopped."""
        self.assertTrue(mock_generator.yielded)
        self.assertTrue(mock_generator.stopped)

    def raiseTestException(self):
        """Raise the shared TEST_EXCEPTION instance."""
        raise self.TEST_EXCEPTION

    def assertAfterWithManagerInvariantsWithError(self, mock_manager):
        """Check *mock_manager* was exited with TEST_EXCEPTION pending."""
        self._assertEnteredAndExited(mock_manager)
        received = mock_manager.exit_args
        self.assertEqual(received[0], RuntimeError)
        self.assertEqual(received[1], self.TEST_EXCEPTION)

    def assertAfterWithGeneratorInvariantsWithError(self, mock_generator):
        """Check the resource was stopped even though an exception occurred."""
        self.assertAfterWithGeneratorInvariantsNoError(mock_generator)





class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """Single with statements whose bodies complete without raising."""

    def testInlineGeneratorSyntax(self):
        # Smoke test: manager created inline, no "as" target.
        with mock_contextmanager_generator():
            pass

    def testUnboundGenerator(self):
        # Manager created beforehand and used without an "as" target.
        mock = mock_contextmanager_generator()
        with mock:
            pass
        self.assertAfterWithManagerInvariantsNoError(mock)

    def testInlineGeneratorBoundSyntax(self):
        # The bound name is still usable after the with statement ends.
        with mock_contextmanager_generator() as foo:
            self.assertInWithGeneratorInvariants(foo)
        # FIXME: In the future, we'll try to keep the bound names from leaking
        self.assertAfterWithGeneratorInvariantsNoError(foo)

    def testInlineGeneratorBoundToExistingVariable(self):
        # The "as" target may rebind a name that already exists.
        foo = None
        with mock_contextmanager_generator() as foo:
            self.assertInWithGeneratorInvariants(foo)
        self.assertAfterWithGeneratorInvariantsNoError(foo)

    def testInlineGeneratorBoundToDottedVariable(self):
        # The "as" target may be an attribute reference.
        with mock_contextmanager_generator() as self.foo:
            self.assertInWithGeneratorInvariants(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.foo)

    def testBoundGenerator(self):
        # Check the manager's state and the yielded resource's state.
        mock = mock_contextmanager_generator()
        with mock as foo:
            self.assertInWithGeneratorInvariants(foo)
            self.assertInWithManagerInvariants(mock)
        self.assertAfterWithGeneratorInvariantsNoError(foo)
        self.assertAfterWithManagerInvariantsNoError(mock)

    def testNestedSingleStatements(self):
        # Two textually nested with statements: the inner manager is fully
        # exited while the outer one is still active.
        mock_a = mock_contextmanager_generator()
        with mock_a as foo:
            mock_b = mock_contextmanager_generator()
            with mock_b as bar:
                self.assertInWithManagerInvariants(mock_a)
                self.assertInWithManagerInvariants(mock_b)
                self.assertInWithGeneratorInvariants(foo)
                self.assertInWithGeneratorInvariants(bar)
            self.assertAfterWithManagerInvariantsNoError(mock_b)
            self.assertAfterWithGeneratorInvariantsNoError(bar)
            self.assertInWithManagerInvariants(mock_a)
            self.assertInWithGeneratorInvariants(foo)
        self.assertAfterWithManagerInvariantsNoError(mock_a)
        self.assertAfterWithGeneratorInvariantsNoError(foo)





class NestedNonexceptionalTestCase(unittest.TestCase,
    ContextmanagerAssertionMixin):
    """Nested()/MockNested combining managers whose bodies complete normally."""

    def testSingleArgInlineGeneratorSyntax(self):
        # Smoke test: one manager, created inline, no "as" target.
        with Nested(mock_contextmanager_generator()):
            pass

    def testSingleArgUnbound(self):
        # Both the wrapped manager and the MockNested wrapper are entered and
        # then exited cleanly.
        mock_contextmanager = mock_contextmanager_generator()
        mock_nested = MockNested(mock_contextmanager)
        with mock_nested:
            self.assertInWithManagerInvariants(mock_contextmanager)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(mock_contextmanager)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testSingleArgBoundToNonTuple(self):
        m = mock_contextmanager_generator()
        # Nested.__enter__ returns a list of every manager's __enter__
        # result, and the whole list is assigned to foo.
        with Nested(m) as foo:
            self.assertInWithManagerInvariants(m)
            self.assertEqual(len(foo), 1)
            self.assertInWithGeneratorInvariants(foo[0])
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToSingleElementParenthesizedList(self):
        m = mock_contextmanager_generator()
        # "(foo)" is just a parenthesized name, not a one-element tuple, so
        # foo again receives the whole list of __enter__ results.
        with Nested(m) as (foo):
            self.assertInWithManagerInvariants(m)
            self.assertEqual(len(foo), 1)
            self.assertInWithGeneratorInvariants(foo[0])
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToMultipleElementTupleError(self):
        # Unpacking a one-element result list into two names fails.
        def shouldThrowValueError():
            with Nested(mock_contextmanager_generator()) as (foo, bar):
                pass
        self.assertRaises(ValueError, shouldThrowValueError)

    def testMultipleArgUnbound(self):
        # Every wrapped manager is active inside the body and exited after.
        m = mock_contextmanager_generator()
        n = mock_contextmanager_generator()
        o = mock_contextmanager_generator()
        mock_nested = MockNested(m, n, o)
        with mock_nested:
            self.assertInWithManagerInvariants(m)
            self.assertInWithManagerInvariants(n)
            self.assertInWithManagerInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(m)
        self.assertAfterWithManagerInvariantsNoError(n)
        self.assertAfterWithManagerInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testMultipleArgBound(self):
        # The result list unpacks into one name per wrapped manager.
        mock_nested = MockNested(mock_contextmanager_generator(),
            mock_contextmanager_generator(), mock_contextmanager_generator())
        with mock_nested as (m, n, o):
            self.assertInWithGeneratorInvariants(m)
            self.assertInWithGeneratorInvariants(n)
            self.assertInWithGeneratorInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithGeneratorInvariantsNoError(m)
        self.assertAfterWithGeneratorInvariantsNoError(n)
        self.assertAfterWithGeneratorInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)





class ExceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """With statements whose bodies (or managers) raise exceptions."""

    def testSingleResource(self):
        # The body's exception reaches __exit__ and still propagates.
        cm = mock_contextmanager_generator()
        def shouldThrow():
            with cm as self.resource:
                self.assertInWithManagerInvariants(cm)
                self.assertInWithGeneratorInvariants(self.resource)
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(cm)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource)

    def testNestedSingleStatements(self):
        # An exception in the innermost body is seen by both managers.
        mock_a = mock_contextmanager_generator()
        mock_b = mock_contextmanager_generator()
        def shouldThrow():
            with mock_a as self.foo:
                with mock_b as self.bar:
                    self.assertInWithManagerInvariants(mock_a)
                    self.assertInWithManagerInvariants(mock_b)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                    self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(mock_a)
        self.assertAfterWithManagerInvariantsWithError(mock_b)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsWithError(self.bar)

    def testMultipleResourcesInSingleStatement(self):
        # Same scenario with both managers combined by one MockNested.
        cm_a = mock_contextmanager_generator()
        cm_b = mock_contextmanager_generator()
        mock_nested = MockNested(cm_a, cm_b)
        def shouldThrow():
            with mock_nested as (self.resource_a, self.resource_b):
                self.assertInWithManagerInvariants(cm_a)
                self.assertInWithManagerInvariants(cm_b)
                self.assertInWithManagerInvariants(mock_nested)
                self.assertInWithGeneratorInvariants(self.resource_a)
                self.assertInWithGeneratorInvariants(self.resource_b)
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(cm_a)
        self.assertAfterWithManagerInvariantsWithError(cm_b)
        self.assertAfterWithManagerInvariantsWithError(mock_nested)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_a)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_b)

    def testNestedExceptionBeforeInnerStatement(self):
        # The exception is raised before the inner with is reached, so
        # mock_b must never be entered or exited.
        mock_a = mock_contextmanager_generator()
        mock_b = mock_contextmanager_generator()
        self.bar = None
        def shouldThrow():
            with mock_a as self.foo:
                self.assertInWithManagerInvariants(mock_a)
                self.assertInWithGeneratorInvariants(self.foo)
                self.raiseTestException()
                with mock_b as self.bar:
                    pass
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(mock_a)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)

        # The inner statement stuff should never have been touched
        self.assertEqual(self.bar, None)
        self.assertFalse(mock_b.enter_called)
        self.assertFalse(mock_b.exit_called)
        self.assertEqual(mock_b.exit_args, None)

    def testNestedExceptionAfterInnerStatement(self):
        # The inner with completes normally; only the outer manager sees the
        # exception raised afterwards.
        mock_a = mock_contextmanager_generator()
        mock_b = mock_contextmanager_generator()
        def shouldThrow():
            with mock_a as self.foo:
                with mock_b as self.bar:
                    self.assertInWithManagerInvariants(mock_a)
                    self.assertInWithManagerInvariants(mock_b)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(mock_a)
        self.assertAfterWithManagerInvariantsNoError(mock_b)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.bar)

    def testRaisedStopIteration1(self):
        # From bug 1462485
        # A StopIteration raised in the body of a @contextmanager-based
        # with statement must propagate out of it.
        @contextmanager
        def cm():
            yield

        def shouldThrow():
            with cm():
                raise StopIteration("from with")

        self.assertRaises(StopIteration, shouldThrow)

    def testRaisedStopIteration2(self):
        # From bug 1462485
        # Same check with a plain class-based manager whose __exit__
        # returns None.
        class cm(object):
            def __enter__(self):
                pass
            def __exit__(self, type, value, traceback):
                pass

        def shouldThrow():
            with cm():
                raise StopIteration("from with")

        self.assertRaises(StopIteration, shouldThrow)

    def testRaisedStopIteration3(self):
        # Another variant where the exception hasn't been instantiated
        # From bug 1705170
        @contextmanager
        def cm():
            yield

        def shouldThrow():
            with cm():
                raise iter([]).next()

        self.assertRaises(StopIteration, shouldThrow)

    def testRaisedGeneratorExit1(self):
        # From bug 1462485
        # A GeneratorExit raised in the body of a @contextmanager-based
        # with statement must propagate out of it.
        @contextmanager
        def cm():
            yield

        def shouldThrow():
            with cm():
                raise GeneratorExit("from with")

        self.assertRaises(GeneratorExit, shouldThrow)

    def testRaisedGeneratorExit2(self):
        # From bug 1462485
        # Same check with a plain class-based manager.
        class cm (object):
            def __enter__(self):
                pass
            def __exit__(self, type, value, traceback):
                pass

        def shouldThrow():
            with cm():
                raise GeneratorExit("from with")

        self.assertRaises(GeneratorExit, shouldThrow)

    def testErrorsInBool(self):
        # issue4589: __exit__ return code may raise an exception
        # when looking at its truth value.
        # Three cases: a true result suppresses the body's exception, a
        # false result lets it propagate, and an error while computing the
        # truth value replaces it.

        class cm(object):
            def __init__(self, bool_conversion):
                # Bool's truth value is whatever bool_conversion() returns
                # (or raises).
                class Bool:
                    def __nonzero__(self):
                        return bool_conversion()
                self.exit_result = Bool()
            def __enter__(self):
                return 3
            def __exit__(self, a, b, c):
                return self.exit_result

        def trueAsBool():
            with cm(lambda: True):
                self.fail("Should NOT see this")
        trueAsBool()

        def falseAsBool():
            with cm(lambda: False):
                self.fail("Should raise")
        self.assertRaises(AssertionError, falseAsBool)

        def failAsBool():
            with cm(lambda: 1//0):
                self.fail("Should NOT see this")
        self.assertRaises(ZeroDivisionError, failAsBool)





class NonLocalFlowControlTestCase(unittest.TestCase):
    """break/continue/return/yield/raise leaving a with body.

    Each test tallies a counter so the exact statements that ran can be
    checked afterwards.
    """

    def testWithBreak(self):
        # break inside the body leaves the enclosing loop immediately.
        counter = 0
        while True:
            counter += 1
            with mock_contextmanager_generator():
                counter += 10
                break
            counter += 100 # Not reached
        self.assertEqual(counter, 11)

    def testWithContinue(self):
        # continue inside the body restarts the enclosing loop; the second
        # pass exits via the counter > 2 check (1 + 10 + 1 == 12).
        counter = 0
        while True:
            counter += 1
            if counter > 2:
                break
            with mock_contextmanager_generator():
                counter += 10
                continue
            counter += 100 # Not reached
        self.assertEqual(counter, 12)

    def testWithReturn(self):
        # return inside the body returns the value computed so far.
        def foo():
            counter = 0
            while True:
                counter += 1
                with mock_contextmanager_generator():
                    counter += 10
                    return counter
                counter += 100 # Not reached
        self.assertEqual(foo(), 11)

    def testWithYield(self):
        # A generator may yield repeatedly from inside a with body.
        def gen():
            with mock_contextmanager_generator():
                yield 12
                yield 13
        x = list(gen())
        self.assertEqual(x, [12, 13])

    def testWithRaise(self):
        # An exception raised in the body skips the rest of the try block.
        counter = 0
        try:
            counter += 1
            with mock_contextmanager_generator():
                counter += 10
                raise RuntimeError
            counter += 100 # Not reached
        except RuntimeError:
            self.assertEqual(counter, 11)
        else:
            self.fail("Didn't raise RuntimeError")





class AssignmentTargetTestCase(unittest.TestCase):
    """Non-trivial "as" targets: subscriptions, attributes, and tuples of them.

    NOTE: these tests rely on dict.keys()/values() returning lists (they are
    compared with lists, sorted in place and indexed), i.e. Python 2 behavior.
    """

    def testSingleComplexTarget(self):
        targets = {1: [0, 1, 2]}
        # Subscription of a nested container as the target.
        with mock_contextmanager_generator() as targets[1][0]:
            self.assertEqual(targets.keys(), [1])
            self.assertEqual(targets[1][0].__class__, MockResource)
        # Target expression that itself contains a method call.
        with mock_contextmanager_generator() as targets.values()[0][1]:
            self.assertEqual(targets.keys(), [1])
            self.assertEqual(targets[1][1].__class__, MockResource)
        # Subscription that creates a new dict key.
        with mock_contextmanager_generator() as targets[2]:
            keys = targets.keys()
            keys.sort()
            self.assertEqual(keys, [1, 2])
        # Attribute reference on an arbitrary object.
        class C: pass
        blah = C()
        with mock_contextmanager_generator() as blah.foo:
            self.assertEqual(hasattr(blah, "foo"), True)

    def testMultipleComplexTargets(self):
        # C.__enter__ returns (1, 2, 3), unpacked into a tuple of targets.
        class C:
            def __enter__(self): return 1, 2, 3
            def __exit__(self, t, v, tb): pass
        targets = {1: [0, 1, 2]}
        with C() as (targets[1][0], targets[1][1], targets[1][2]):
            self.assertEqual(targets, {1: [1, 2, 3]})
        with C() as (targets.values()[0][2], targets.values()[0][1], targets.values()[0][0]):
            self.assertEqual(targets, {1: [3, 2, 1]})
        with C() as (targets[1], targets[2], targets[3]):
            self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
        class B: pass
        blah = B()
        with C() as (blah.one, blah.two, blah.three):
            self.assertEqual(blah.one, 1)
            self.assertEqual(blah.two, 2)
            self.assertEqual(blah.three, 3)





class ExitSwallowsExceptionTestCase(unittest.TestCase):
    """The value returned by __exit__ decides whether an exception escapes."""

    def testExitTrueSwallowsException(self):
        # __exit__ returning True suppresses the ZeroDivisionError.
        class AfricanSwallow:
            def __enter__(self): pass
            def __exit__(self, t, v, tb): return True
        try:
            with AfricanSwallow():
                1/0
        except ZeroDivisionError:
            self.fail("ZeroDivisionError should have been swallowed")

    def testExitFalseDoesntSwallowException(self):
        # __exit__ returning False lets the ZeroDivisionError propagate.
        class EuropeanSwallow:
            def __enter__(self): pass
            def __exit__(self, t, v, tb): return False
        try:
            with EuropeanSwallow():
                1/0
        except ZeroDivisionError:
            pass
        else:
            self.fail("ZeroDivisionError should have been raised")





def test_main():
    """Run every test case class defined in this module via run_unittest."""
    test_classes = (
        FailureTestCase,
        NonexceptionalTestCase,
        NestedNonexceptionalTestCase,
        ExceptionalTestCase,
        NonLocalFlowControlTestCase,
        AssignmentTargetTestCase,
        ExitSwallowsExceptionTestCase,
    )
    run_unittest(*test_classes)





# Allow running this test module directly as a script.
if __name__ == '__main__':
    test_main()

