import sys

import imp

import os

import unittest

from test import test_support





# Source executed by every fake module the test importers serve.
# get_name()/get_file() let the tests inspect the __name__ and __file__
# that the loader assigned to the module.
test_src = """\

def get_name():

    return __name__

def get_file():

    return __file__

"""



# One-line import statements prepended to test_src to build package
# submodules that exercise different flavours of "import sub".
absimp = "import sub\n"

relimp = "from . import sub\n"

deeprelimp = "from .... import sub\n"

futimp = "from __future__ import absolute_import\n"



# Variant of test_src that additionally defines `reloaded`; the tests swap
# it in for "reloadmodule" to check that reload() re-executes the code.
reload_src = test_src+"""\

reloaded = True

"""



test_co = compile(test_src, "<???>", "exec")

reload_co = compile(reload_src, "<???>", "exec")



# Code objects for package submodules: each is test_src preceded by one
# (or two) of the import prefixes above.
test2_oldabs_co = compile(absimp + test_src, "<???>", "exec")

test2_newabs_co = compile(futimp + absimp + test_src, "<???>", "exec")

test2_newrel_co = compile(relimp + test_src, "<???>", "exec")

test2_deeprel_co = compile(deeprelimp + test_src, "<???>", "exec")

test2_futrel_co = compile(futimp + relimp + test_src, "<???>", "exec")



# The only path entry TestImporter's constructor accepts; any other value
# makes it raise ImportError.
test_path = "!!!_test_!!!"





class TestImporter:



    modules = {

        "hooktestmodule": (False, test_co),

        "hooktestpackage": (True, test_co),

        "hooktestpackage.sub": (True, test_co),

        "hooktestpackage.sub.subber": (True, test_co),

        "hooktestpackage.oldabs": (False, test2_oldabs_co),

        "hooktestpackage.newabs": (False, test2_newabs_co),

        "hooktestpackage.newrel": (False, test2_newrel_co),

        "hooktestpackage.sub.subber.subest": (True, test2_deeprel_co),

        "hooktestpackage.futrel": (False, test2_futrel_co),

        "sub": (False, test_co),

        "reloadmodule": (False, test_co),

    }



    def __init__(self, path=test_path):

        if path != test_path:

            # if out class is on sys.path_hooks, we must raise

            # ImportError for any path item that we can't handle.

            raise ImportError

        self.path = path



    def _get__path__(self):

        raise NotImplementedError



    def find_module(self, fullname, path=None):

        if fullname in self.modules:

            return self

        else:

            return None



    def load_module(self, fullname):

        ispkg, code = self.modules[fullname]

        mod = sys.modules.setdefault(fullname,imp.new_module(fullname))

        mod.__file__ = "<%s>" % self.__class__.__name__

        mod.__loader__ = self

        if ispkg:

            mod.__path__ = self._get__path__()

        exec code in mod.__dict__

        return mod





class MetaImporter(TestImporter):
    """TestImporter variant installed directly on sys.meta_path.

    Packages it loads are given an empty __path__.
    """

    def _get__path__(self):
        no_entries = list()
        return no_entries



class PathImporter(TestImporter):
    """TestImporter variant installed as a sys.path_hooks callable.

    Packages it loads get a __path__ holding this importer's own path
    entry, so submodule lookups on that entry go through the hook again.
    """

    def _get__path__(self):
        return list((self.path,))





class ImportBlocker:
    """Meta-path hook that makes selected modules unimportable.

    Place an ImportBlocker instance on sys.meta_path and you can be sure
    the modules you specified can't be imported, even if they are
    builtins: find_module claims each blocked name and load_module then
    refuses to load it.
    """

    def __init__(self, *namestoblock):
        # Dict used as a set of blocked fully-qualified module names.
        self.namestoblock = dict.fromkeys(namestoblock)

    def find_module(self, fullname, path=None):
        """Return self for a blocked name (so load_module runs), else None."""
        if fullname in self.namestoblock:
            return self
        return None

    def load_module(self, fullname):
        """Always refuse to load by raising ImportError."""
        # Call form is valid on every Python; the old "raise E, msg" form
        # is deprecated and a SyntaxError on Python 3.
        raise ImportError("I dare you")





class ImpWrapper:
    """PEP 302 finder that delegates the search to imp.find_module.

    With path None it acts as a sys.meta_path entry that only handles
    top-level names; constructed with a directory (as a sys.path_hooks
    callable) it searches that directory for the last name component.
    """

    def __init__(self, path=None):
        # A path hook must reject entries it can't handle with ImportError.
        if path is None or os.path.isdir(path):
            self.path = path
        else:
            raise ImportError

    def find_module(self, fullname, path=None):
        """Return an ImpLoader for fullname, or None if it isn't found."""
        leaf = fullname.rsplit(".", 1)[-1]
        if self.path is None:
            # Dotted names are left for the per-directory instances.
            if "." in fullname:
                return None
            search_dirs = None
        else:
            search_dirs = [self.path]
        try:
            handle, pathname, description = imp.find_module(leaf, search_dirs)
        except ImportError:
            return None
        return ImpLoader(handle, pathname, description)





class ImpLoader:
    """PEP 302 loader wrapping the result of imp.find_module."""

    def __init__(self, file, filename, stuff):
        # The triple returned by imp.find_module: an open file (or None),
        # the path name, and the (suffix, mode, type) description.
        self.file = file
        self.filename = filename
        self.stuff = stuff

    def load_module(self, fullname):
        """Load fullname via imp.load_module and tag it with this loader.

        The file handle obtained from imp.find_module is closed whether or
        not loading succeeds.
        """
        try:
            mod = imp.load_module(fullname, self.file, self.filename,
                                  self.stuff)
        finally:
            # The caller of imp.load_module must close the file even when
            # loading raises; otherwise the handle leaks.
            if self.file:
                self.file.close()
        mod.__loader__ = self  # for introspection
        return mod





class ImportHooksBaseTestCase(unittest.TestCase):
    """Snapshot global import state in setUp and restore it in tearDown."""

    def setUp(self):
        # Copies of the import-related lists the tests mutate in place.
        self.path = list(sys.path)
        self.meta_path = list(sys.meta_path)
        self.path_hooks = list(sys.path_hooks)
        sys.path_importer_cache.clear()
        self.modules_before = dict(sys.modules)

    def tearDown(self):
        # Restore contents in place so other references to these lists
        # observe the original values again.
        saved_lists = ((sys.path, self.path),
                       (sys.meta_path, self.meta_path),
                       (sys.path_hooks, self.path_hooks))
        for live, saved in saved_lists:
            live[:] = saved
        sys.path_importer_cache.clear()
        sys.modules.clear()
        sys.modules.update(self.modules_before)





class ImportHooksTestCase(ImportHooksBaseTestCase):
    """Exercise imports routed through sys.meta_path and sys.path_hooks."""

    def doTestImports(self, importer=None):
        """Import the fake modules served by TestImporter and check them.

        importer: when given, the top-level module and the packages must
        report it as their __loader__.
        """
        import hooktestmodule
        import hooktestpackage
        import hooktestpackage.sub
        import hooktestpackage.sub.subber
        self.assertEqual(hooktestmodule.get_name(),
                         "hooktestmodule")
        self.assertEqual(hooktestpackage.get_name(),
                         "hooktestpackage")
        self.assertEqual(hooktestpackage.sub.get_name(),
                         "hooktestpackage.sub")
        self.assertEqual(hooktestpackage.sub.subber.get_name(),
                         "hooktestpackage.sub.subber")
        if importer:
            self.assertEqual(hooktestmodule.__loader__, importer)
            self.assertEqual(hooktestpackage.__loader__, importer)
            self.assertEqual(hooktestpackage.sub.__loader__, importer)
            self.assertEqual(hooktestpackage.sub.subber.__loader__, importer)

        # reload() must re-execute the loader's current code for the module.
        TestImporter.modules['reloadmodule'] = (False, test_co)
        import reloadmodule
        self.assertFalse(hasattr(reloadmodule,'reloaded'))

        TestImporter.modules['reloadmodule'] = (False, reload_co)
        reload(reloadmodule)
        self.assertTrue(hasattr(reloadmodule,'reloaded'))

        # Without absolute_import, "import sub" inside the package binds
        # the sibling hooktestpackage.sub.
        import hooktestpackage.oldabs
        self.assertEqual(hooktestpackage.oldabs.get_name(),
                         "hooktestpackage.oldabs")
        self.assertEqual(hooktestpackage.oldabs.sub,
                         hooktestpackage.sub)

        # Explicit relative import.
        import hooktestpackage.newrel
        self.assertEqual(hooktestpackage.newrel.get_name(),
                         "hooktestpackage.newrel")
        self.assertEqual(hooktestpackage.newrel.sub,
                         hooktestpackage.sub)

        # Four leading dots from the subest package reach hooktestpackage.
        import hooktestpackage.sub.subber.subest as subest
        self.assertEqual(subest.get_name(),
                         "hooktestpackage.sub.subber.subest")
        self.assertEqual(subest.sub,
                         hooktestpackage.sub)

        # Explicit relative import combined with absolute_import.
        import hooktestpackage.futrel
        self.assertEqual(hooktestpackage.futrel.get_name(),
                         "hooktestpackage.futrel")
        self.assertEqual(hooktestpackage.futrel.sub,
                         hooktestpackage.sub)

        import sub
        self.assertEqual(sub.get_name(), "sub")

        # With absolute_import, "import sub" binds the top-level module.
        import hooktestpackage.newabs
        self.assertEqual(hooktestpackage.newabs.get_name(),
                         "hooktestpackage.newabs")
        self.assertEqual(hooktestpackage.newabs.sub, sub)

    def testMetaPath(self):
        i = MetaImporter()
        sys.meta_path.append(i)
        self.doTestImports(i)

    def testPathHook(self):
        sys.path_hooks.append(PathImporter)
        sys.path.append(test_path)
        self.doTestImports()

    def testBlocker(self):
        mname = "exceptions"  # an arbitrary harmless builtin module
        if mname in sys.modules:
            del sys.modules[mname]
        sys.meta_path.append(ImportBlocker(mname))
        try:
            __import__(mname)
        except ImportError:
            pass
        else:
            self.fail("'%s' was not supposed to be importable" % mname)

    def testImpWrapper(self):
        i = ImpWrapper()
        sys.meta_path.append(i)
        sys.path_hooks.append(ImpWrapper)
        mnames = ("colorsys", "urlparse", "distutils.core", "compiler.misc")
        for mname in mnames:
            parent = mname.split(".")[0]
            # Snapshot the keys because we delete while iterating, and
            # purge only the parent package and its submodules -- not
            # unrelated modules that merely share a name prefix.
            for n in list(sys.modules):
                if n == parent or n.startswith(parent + "."):
                    del sys.modules[n]
        for mname in mnames:
            m = __import__(mname, globals(), locals(), ["__dummy__"])
            m.__loader__  # to make sure we actually handled the import





def test_main():
    """Run this module's test cases through the regression-test runner."""
    case_classes = [ImportHooksTestCase]
    test_support.run_unittest(*case_classes)



# Allow running this file directly as a script.
if __name__ == "__main__":

    test_main()

