#-*- coding: ISO-8859-1 -*-

# pysqlite2/test/userfunctions.py: tests for user-defined functions and

#                                  aggregates.

#

# Copyright (C) 2005-2007 Gerhard Hring <gh@ghaering.de>

#

# This file is part of pysqlite.

#

# This software is provided 'as-is', without any express or implied

# warranty.  In no event will the authors be held liable for any damages

# arising from the use of this software.

#

# Permission is granted to anyone to use this software for any purpose,

# including commercial applications, and to alter it and redistribute it

# freely, subject to the following restrictions:

#

# 1. The origin of this software must not be misrepresented; you must not

#    claim that you wrote the original software. If you use this software

#    in a product, an acknowledgment in the product documentation would be

#    appreciated but is not required.

# 2. Altered source versions must be plainly marked as such, and must not be

#    misrepresented as being the original software.

# 3. This notice may not be removed or altered from any source distribution.



import unittest

import sqlite3 as sqlite



def func_returntext():

    return "foo"

def func_returnunicode():

    return u"bar"

def func_returnint():

    return 42

def func_returnfloat():

    return 3.14

def func_returnnull():

    return None

def func_returnblob():

    return buffer("blob")

def func_raiseexception():

    5/0



def func_isstring(v):

    return type(v) is unicode

def func_isint(v):

    return type(v) is int

def func_isfloat(v):

    return type(v) is float

def func_isnone(v):

    return type(v) is type(None)

def func_isblob(v):

    return type(v) is buffer



class AggrNoStep:

    def __init__(self):

        pass



    def finalize(self):

        return 1



class AggrNoFinalize:

    def __init__(self):

        pass



    def step(self, x):

        pass



class AggrExceptionInInit:

    def __init__(self):

        5/0



    def step(self, x):

        pass



    def finalize(self):

        pass



class AggrExceptionInStep:

    def __init__(self):

        pass



    def step(self, x):

        5/0



    def finalize(self):

        return 42



class AggrExceptionInFinalize:

    def __init__(self):

        pass



    def step(self, x):

        pass



    def finalize(self):

        5/0



class AggrCheckType:

    def __init__(self):

        self.val = None



    def step(self, whichType, val):

        theType = {"str": unicode, "int": int, "float": float, "None": type(None), "blob": buffer}

        self.val = int(theType[whichType] is type(val))



    def finalize(self):

        return self.val



class AggrSum:

    def __init__(self):

        self.val = 0.0



    def step(self, val):

        self.val += val



    def finalize(self):

        return self.val



class FunctionTests(unittest.TestCase):

    def setUp(self):

        self.con = sqlite.connect(":memory:")



        self.con.create_function("returntext", 0, func_returntext)

        self.con.create_function("returnunicode", 0, func_returnunicode)

        self.con.create_function("returnint", 0, func_returnint)

        self.con.create_function("returnfloat", 0, func_returnfloat)

        self.con.create_function("returnnull", 0, func_returnnull)

        self.con.create_function("returnblob", 0, func_returnblob)

        self.con.create_function("raiseexception", 0, func_raiseexception)



        self.con.create_function("isstring", 1, func_isstring)

        self.con.create_function("isint", 1, func_isint)

        self.con.create_function("isfloat", 1, func_isfloat)

        self.con.create_function("isnone", 1, func_isnone)

        self.con.create_function("isblob", 1, func_isblob)



    def tearDown(self):

        self.con.close()



    def CheckFuncErrorOnCreate(self):

        try:

            self.con.create_function("bla", -100, lambda x: 2*x)

            self.fail("should have raised an OperationalError")

        except sqlite.OperationalError:

            pass



    def CheckFuncRefCount(self):

        def getfunc():

            def f():

                return 1

            return f

        f = getfunc()

        globals()["foo"] = f

        # self.con.create_function("reftest", 0, getfunc())

        self.con.create_function("reftest", 0, f)

        cur = self.con.cursor()

        cur.execute("select reftest()")



    def CheckFuncReturnText(self):

        cur = self.con.cursor()

        cur.execute("select returntext()")

        val = cur.fetchone()[0]

        self.failUnlessEqual(type(val), unicode)

        self.failUnlessEqual(val, "foo")



    def CheckFuncReturnUnicode(self):

        cur = self.con.cursor()

        cur.execute("select returnunicode()")

        val = cur.fetchone()[0]

        self.failUnlessEqual(type(val), unicode)

        self.failUnlessEqual(val, u"bar")



    def CheckFuncReturnInt(self):

        cur = self.con.cursor()

        cur.execute("select returnint()")

        val = cur.fetchone()[0]

        self.failUnlessEqual(type(val), int)

        self.failUnlessEqual(val, 42)



    def CheckFuncReturnFloat(self):

        cur = self.con.cursor()

        cur.execute("select returnfloat()")

        val = cur.fetchone()[0]

        self.failUnlessEqual(type(val), float)

        if val < 3.139 or val > 3.141:

            self.fail("wrong value")



    def CheckFuncReturnNull(self):

        cur = self.con.cursor()

        cur.execute("select returnnull()")

        val = cur.fetchone()[0]

        self.failUnlessEqual(type(val), type(None))

        self.failUnlessEqual(val, None)



    def CheckFuncReturnBlob(self):

        cur = self.con.cursor()

        cur.execute("select returnblob()")

        val = cur.fetchone()[0]

        self.failUnlessEqual(type(val), buffer)

        self.failUnlessEqual(val, buffer("blob"))



    def CheckFuncException(self):

        cur = self.con.cursor()

        try:

            cur.execute("select raiseexception()")

            cur.fetchone()

            self.fail("should have raised OperationalError")

        except sqlite.OperationalError, e:

            self.failUnlessEqual(e.args[0], 'user-defined function raised exception')



    def CheckParamString(self):

        cur = self.con.cursor()

        cur.execute("select isstring(?)", ("foo",))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckParamInt(self):

        cur = self.con.cursor()

        cur.execute("select isint(?)", (42,))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckParamFloat(self):

        cur = self.con.cursor()

        cur.execute("select isfloat(?)", (3.14,))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckParamNone(self):

        cur = self.con.cursor()

        cur.execute("select isnone(?)", (None,))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckParamBlob(self):

        cur = self.con.cursor()

        cur.execute("select isblob(?)", (buffer("blob"),))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



class AggregateTests(unittest.TestCase):

    def setUp(self):

        self.con = sqlite.connect(":memory:")

        cur = self.con.cursor()

        cur.execute("""

            create table test(

                t text,

                i integer,

                f float,

                n,

                b blob

                )

            """)

        cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",

            ("foo", 5, 3.14, None, buffer("blob"),))



        self.con.create_aggregate("nostep", 1, AggrNoStep)

        self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)

        self.con.create_aggregate("excInit", 1, AggrExceptionInInit)

        self.con.create_aggregate("excStep", 1, AggrExceptionInStep)

        self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)

        self.con.create_aggregate("checkType", 2, AggrCheckType)

        self.con.create_aggregate("mysum", 1, AggrSum)



    def tearDown(self):

        #self.cur.close()

        #self.con.close()

        pass



    def CheckAggrErrorOnCreate(self):

        try:

            self.con.create_function("bla", -100, AggrSum)

            self.fail("should have raised an OperationalError")

        except sqlite.OperationalError:

            pass



    def CheckAggrNoStep(self):

        cur = self.con.cursor()

        try:

            cur.execute("select nostep(t) from test")

            self.fail("should have raised an AttributeError")

        except AttributeError, e:

            self.failUnlessEqual(e.args[0], "AggrNoStep instance has no attribute 'step'")



    def CheckAggrNoFinalize(self):

        cur = self.con.cursor()

        try:

            cur.execute("select nofinalize(t) from test")

            val = cur.fetchone()[0]

            self.fail("should have raised an OperationalError")

        except sqlite.OperationalError, e:

            self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")



    def CheckAggrExceptionInInit(self):

        cur = self.con.cursor()

        try:

            cur.execute("select excInit(t) from test")

            val = cur.fetchone()[0]

            self.fail("should have raised an OperationalError")

        except sqlite.OperationalError, e:

            self.failUnlessEqual(e.args[0], "user-defined aggregate's '__init__' method raised error")



    def CheckAggrExceptionInStep(self):

        cur = self.con.cursor()

        try:

            cur.execute("select excStep(t) from test")

            val = cur.fetchone()[0]

            self.fail("should have raised an OperationalError")

        except sqlite.OperationalError, e:

            self.failUnlessEqual(e.args[0], "user-defined aggregate's 'step' method raised error")



    def CheckAggrExceptionInFinalize(self):

        cur = self.con.cursor()

        try:

            cur.execute("select excFinalize(t) from test")

            val = cur.fetchone()[0]

            self.fail("should have raised an OperationalError")

        except sqlite.OperationalError, e:

            self.failUnlessEqual(e.args[0], "user-defined aggregate's 'finalize' method raised error")



    def CheckAggrCheckParamStr(self):

        cur = self.con.cursor()

        cur.execute("select checkType('str', ?)", ("foo",))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckAggrCheckParamInt(self):

        cur = self.con.cursor()

        cur.execute("select checkType('int', ?)", (42,))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckAggrCheckParamFloat(self):

        cur = self.con.cursor()

        cur.execute("select checkType('float', ?)", (3.14,))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckAggrCheckParamNone(self):

        cur = self.con.cursor()

        cur.execute("select checkType('None', ?)", (None,))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckAggrCheckParamBlob(self):

        cur = self.con.cursor()

        cur.execute("select checkType('blob', ?)", (buffer("blob"),))

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 1)



    def CheckAggrCheckAggrSum(self):

        cur = self.con.cursor()

        cur.execute("delete from test")

        cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])

        cur.execute("select mysum(i) from test")

        val = cur.fetchone()[0]

        self.failUnlessEqual(val, 60)



def authorizer_cb(action, arg1, arg2, dbname, source):

    if action != sqlite.SQLITE_SELECT:

        return sqlite.SQLITE_DENY

    if arg2 == 'c2' or arg1 == 't2':

        return sqlite.SQLITE_DENY

    return sqlite.SQLITE_OK



class AuthorizerTests(unittest.TestCase):

    def setUp(self):

        self.con = sqlite.connect(":memory:")

        self.con.executescript("""

            create table t1 (c1, c2);

            create table t2 (c1, c2);

            insert into t1 (c1, c2) values (1, 2);

            insert into t2 (c1, c2) values (4, 5);

            """)



        # For our security test:

        self.con.execute("select c2 from t2")



        self.con.set_authorizer(authorizer_cb)



    def tearDown(self):

        pass



    def CheckTableAccess(self):

        try:

            self.con.execute("select * from t2")

        except sqlite.DatabaseError, e:

            if not e.args[0].endswith("prohibited"):

                self.fail("wrong exception text: %s" % e.args[0])

            return

        self.fail("should have raised an exception due to missing privileges")



    def CheckColumnAccess(self):

        try:

            self.con.execute("select c2 from t1")

        except sqlite.DatabaseError, e:

            if not e.args[0].endswith("prohibited"):

                self.fail("wrong exception text: %s" % e.args[0])

            return

        self.fail("should have raised an exception due to missing privileges")



def suite():

    function_suite = unittest.makeSuite(FunctionTests, "Check")

    aggregate_suite = unittest.makeSuite(AggregateTests, "Check")

    authorizer_suite = unittest.makeSuite(AuthorizerTests, "Check")

    return unittest.TestSuite((function_suite, aggregate_suite, authorizer_suite))



def test():

    runner = unittest.TextTestRunner()

    runner.run(suite())



if __name__ == "__main__":

    test()

