##############################################################################
#
# Copyright (c) 2008-2009 SIA "KN dati". (http://kndati.lv) All Rights Reserved.
#                    General contacts <info@kndati.lv>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsability of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# garantees and support are strongly adviced to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
#
##############################################################################

from barcode import barcode
from tools import translate
from domain_parser import domain2statement
from currency_to_text import currency_to_text
import base64
import StringIO
from PIL import Image
import relatorio
import pooler
import time
import osv


class ExtraFunctions:
    """ Helper functions exposed to report templates.

        An instance binds a database cursor, a user id, a report id and a
        context; ``self.functions`` maps the names usable inside a template
        to the bound helpers defined below.
    """
    def __init__(self, cr, uid, report_id, context):
        """ Store the call environment and build the template function map.

            cr        -- database cursor; ``cr.dbname`` selects the object pool
            uid       -- id of the user on whose behalf helpers query the pool
            report_id -- id of the report, used as ``res_id`` for translations
            context   -- dict; ``context['company']`` is read here and
                         ``context['lang']`` is read by ``_get_lang``
        """
        self.cr = cr
        self.uid = uid
        self.pool = pooler.get_pool(self.cr.dbname)
        self.report_id = report_id
        self.context = context
        self.functions = {
            'asarray': self._asarray,
            'asimage': self._asimage,
            'get_attachments': self._get_attachments,
            'get_name': self._get_name,
            'get_label': self._get_label,
            'getLang': self._get_lang,
            'get_selection_item': self._get_selection_item,
            'safe': self._get_safe,
            'countif': self._countif,
            'count': self._count,
            'sumif': self._sumif,
            'sum': self._sum,
            'max': self._max,
            'min': self._min,
            'average': self._average,
            'large': self._large,
            'small': self._small,
            'count_blank': self._count_blank,
            '_': self._translate_text,
            'gettext': self._translate_text,
            # The company currency becomes the default currency of the helper.
            'currency_to_text': self._currency2text(context['company'].currency_id.code),
            'barcode': barcode.make_barcode,
            'debugit': self.debugit,
            'dec_to_time': self._dec2time,
            'bool_to_text': self._bool2text,
            'chunks': self._chunks,
            'browse': self._browse,
            'field_size': self._field_size,
            'time': time,
        }

    # ------------------------------------------------------------------
    # Private helpers shared by several template functions
    # ------------------------------------------------------------------
    def _evaluate(self, objects, expression, condition=None):
        """ Evaluate ``expression`` once per item of ``objects``.

            Inside the expression the current item is named ``o`` and the
            whole collection ``objects``.  When ``condition`` is given, only
            items for which it evaluates to a true value are included.
            Both strings are compiled before the first item is visited, so a
            syntax error is raised even for an empty collection.

            Returns the list of results, in iteration order.

            SECURITY: both strings are passed to eval(); they must come from
            trusted report templates only, never from end-user input.
        """
        value_code = compile(expression.strip(), '<report expression>', 'eval')
        condition_code = None
        if condition is not None:
            condition_code = compile(condition.strip(), '<report condition>', 'eval')
        # A single globals dict is reused so that lambdas/comprehensions
        # inside the expression can still resolve ``o``.
        scope = {'objects': objects}
        results = []
        for item in objects:
            scope['o'] = item
            if condition_code is None or eval(condition_code, scope):
                results.append(eval(value_code, scope))
        return results

    def _nth_ranked(self, values, n, reverse):
        """ Return the n-th (1-based) item of ``values`` in sorted order.

            ``reverse`` selects descending order.  Returns None when ``n`` is
            not a whole number inside ``1..len(values)``.
        """
        rank = int(n)
        if rank != n or rank < 1 or rank > len(values):
            return None
        return sorted(values, reverse=reverse)[rank - 1]

    def _version_tuple(self, version):
        """ Turn a dotted version string into a tuple of ints.

            Parsing stops at the first component without leading digits,
            e.g. '0.5.2dev' -> (0, 5, 2) and '0.10' -> (0, 10).
        """
        numbers = []
        for chunk in str(version).split('.'):
            digits = ''
            for char in chunk:
                if not char.isdigit():
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
        return tuple(numbers)

    # ------------------------------------------------------------------
    # Template functions
    # ------------------------------------------------------------------
    def _get_lang(self):
        """ Return the language code stored under ``context['lang']``. """
        return self.context['lang']

    def _bool2text(self, val):
        """ Map a value to "YES" / "NO" / "empty".

            Any true value gives "YES", a value equal to False (this includes
            0 and 0.0) gives "NO" and None gives "empty".  Other false values
            such as '' or [] give None.
        """
        if val:
            return "YES"
        if val == False:
            return "NO"
        if val is None:
            return "empty"
        return None

    def _dec2time(self, dec, h_format, min_format):
        """ Format a decimal number of hours as text.

            ``h_format`` is a pattern containing '%H' and ``min_format`` a
            pattern containing '%M'.  Returns None for a zero/empty value,
            only the minutes part when the whole-hour part is 0, only the
            hours part when there is no fraction, otherwise both joined.
        """
        if not dec:
            return None
        hours = int(dec)
        minutes = int(round((dec - hours) * 60))
        if abs(minutes) == 60:
            # The fraction rounded up to a full hour (e.g. 1.9999): carry it
            # into the hour part instead of printing "60" minutes.
            if minutes > 0:
                hours += 1
            else:
                hours -= 1
            return h_format.replace('%H', str(hours))
        minute_text = min_format.replace('%M', str(minutes))
        if hours == 0:
            return minute_text
        hour_text = h_format.replace('%H', str(hours))
        if dec - hours == 0.0:
            return hour_text
        return hour_text + minute_text

    def _currency2text(self, currency):
        """ Build the 'currency_to_text' template function.

            The returned callable takes the amount plus an optional currency
            (defaulting to ``currency``) and an optional language (defaulting
            to the context language) and returns the text as unicode.
        """
        def c_to_text(sum, currency=currency, language=None):
            text = currency_to_text(sum, currency, language or self._get_lang())
            if isinstance(text, unicode):
                # Already decoded; unicode(text, "UTF-8") would raise.
                return text
            return unicode(text, "UTF-8")
        return c_to_text

    def _translate_text(self, source):
        """ Translate ``source`` into the context language.

            If no 'report' entry for this report and source text is found in
            'ir.translation', one is created first.  Returns the result of
            ``translate(...)`` or ``source`` itself when that is empty.
        """
        trans_obj = self.pool.get('ir.translation')
        lang = self._get_lang()
        existing = trans_obj.search(self.cr, self.uid, [
            ('res_id', '=', self.report_id),
            ('type', '=', 'report'),
            ('src', '=', source),
        ])
        if not existing:
            trans_obj.create(self.cr, self.uid, {
                'src': source,
                'type': 'report',
                'lang': lang,
                'res_id': self.report_id,
                'name': ('ir.actions.report.xml,%s' % source)[:128],
            })
        return translate(self.cr, False, 'report', lang, source) or source

    def _countif(self, attr, domain):
        """ Count the items of ``attr`` matching ``domain``.

            ``domain`` is converted by ``domain2statement`` into a condition
            that is evaluated with each item bound to ``o``.
        """
        statement = domain2statement(domain)
        return len(self._evaluate(attr, 'o', statement))

    def _count_blank(self, attr, field):
        """ Count the items of ``attr`` whose ``field`` is false/empty. """
        return len(self._evaluate(attr, 'o', 'not o.%s' % field))

    def _count(self, attr):
        """ Return the number of items in ``attr``. """
        return len(attr)

    def _sumif(self, attr, sum_field, domain):
        """ Sum ``float(o.<sum_field>)`` over the items matching ``domain``.

            Returns the integer 0 when nothing matches.
        """
        statement = domain2statement(domain)
        return sum(self._evaluate(attr, 'float(o.%s)' % sum_field, statement), 0)

    def _sum(self, attr, sum_field):
        """ Sum ``float(o.<sum_field>)`` over all items of ``attr``.

            Returns the integer 0 for an empty collection.
        """
        return sum(self._evaluate(attr, 'float(o.%s)' % sum_field), 0)

    def _max(self, attr, field):
        """ Return the largest ``o.<field>`` value.

            Raises ValueError when ``attr`` is empty.
        """
        return max(self._asarray(attr, field))

    def _min(self, attr, field):
        """ Return the smallest ``o.<field>`` value.

            Raises ValueError when ``attr`` is empty.
        """
        return min(self._asarray(attr, field))

    def _average(self, attr, field):
        """ Return the arithmetic mean of the ``o.<field>`` values as float.

            Raises ZeroDivisionError when ``attr`` is empty.
        """
        values = self._asarray(attr, field)
        return float(sum(values)) / float(len(values))

    def _asarray(self, attr, field):
        """ Return the list of ``o.<field>`` for every item ``o`` of ``attr``.

            ``field`` is spliced into an evaluated expression, so a dotted
            path such as 'partner_id.name' is accepted as well.
        """
        return self._evaluate(attr, 'o.%s' % field)

    def _get_name(self, obj):
        """ Return the ``name_get`` display name of a browse record.

            Anything that is not a browse record gives ''.
        """
        if isinstance(obj, osv.orm.browse_record):
            model = self.pool.get(obj._table_name)
            names = model.name_get(self.cr, self.uid, [obj.id],
                                   {'lang': self._get_lang()})
            return names[0][1]
        return ''

    def _get_label(self, obj, field):
        """ Return the translated label of ``field``.

            Returns None when the field value on ``obj`` is empty and ''
            when anything goes wrong while looking the label up.
        """
        try:
            if not getattr(obj, field):
                return None
            label = self.pool.get(obj._table_name)._columns[field].string
            return translate(self.cr, False, 'field', self._get_lang(), label) or label
        except Exception:
            # Deliberately best-effort: a template must not fail on a label.
            return ''

    def _field_size(self, obj, field):
        """ Return the ``size`` attribute of the column behind ``field``.

            Returns None when the field value on ``obj`` is empty and ''
            when anything goes wrong while looking the column up.
        """
        try:
            if not getattr(obj, field):
                return None
            return self.pool.get(obj._table_name)._columns[field].size
        except Exception:
            # Deliberately best-effort, see _get_label.
            return ''

    def _get_selection_item(self, obj, field):
        """ Return the label of the current value of a selection field.

            The column's ``selection`` may be a list/tuple of (value, label)
            pairs or a callable returning such pairs.  Returns '' when the
            field is empty or the lookup fails.
        """
        try:
            field_val = getattr(obj, field)
            if not field_val:
                return ''
            model = self.pool.get(obj._table_name)
            selection = model._columns[field].selection
            if isinstance(selection, (list, tuple)):
                pairs = selection
            else:
                pairs = selection(model, self.cr, self.uid,
                                  {'lang': self._get_lang()})
            return dict(pairs)[field_val]
        except Exception:
            # Deliberately best-effort, see _get_label.
            return ''

    def _get_attachments(self, o, index=None):
        """ Return the ``datas`` of the attachments linked to record ``o``.

            index -- a string restricts the search to attachments with that
                     name; an integer selects one entry of the result list
                     (IndexError if out of range); None applies no filter.

            Without an integer index the result is the single ``datas``
            value when exactly one attachment has data, otherwise the list.
        """
        attach_obj = self.pool.get('ir.attachment')
        srch_param = [('res_model', '=', o._name), ('res_id', '=', o.id)]
        # basestring, not str: names may also arrive as unicode objects.
        if isinstance(index, basestring):
            srch_param.append(('name', '=', index))
        attachment_ids = attach_obj.search(self.cr, self.uid, srch_param)
        records = attach_obj.read(self.cr, self.uid, attachment_ids, ['datas'])
        res = [record['datas'] for record in records if record['datas']]
        if isinstance(index, int) and not isinstance(index, bool):
            return res[index]
        if len(res) == 1:
            return res[0]
        return res

    def _asimage(self, field_value, rotate=None):
        """ Decode a base64 image field for insertion into a report.

            field_value -- base64 encoded image data
            rotate      -- optional rotation angle, passed to ``Image.rotate``

            Returns ``(buffer, mimetype)`` and, when relatorio is at least
            version 0.5.2, additionally the width and height as strings in
            inches (pixels / 96.0).  An empty value gives an empty buffer
            with mimetype 'image/png'.
        """
        if not field_value:
            return StringIO.StringIO(), 'image/png'
        tf = StringIO.StringIO(base64.decodestring(field_value))
        im = Image.open(tf)
        # Remember the format now: the image returned by rotate() was not
        # read from a file and therefore carries no format of its own.
        image_format = im.format.lower()
        if rotate is not None:
            try:
                rotated = im.rotate(int(rotate))
                # Use a fresh buffer so no bytes of the old image survive.
                rotated_buffer = StringIO.StringIO()
                rotated.save(rotated_buffer, image_format)
            except Exception:
                # Rotation is best-effort: keep the unrotated image.
                pass
            else:
                im, tf = rotated, rotated_buffer
        tf.seek(0)
        mimetype = 'image/%s' % image_format
        # Compare versions numerically; as strings '0.10' sorts below '0.5.2'.
        if self._version_tuple(relatorio.__version__) >= (0, 5, 2):
            size_x = str(im.size[0]/96.0)+'in'
            size_y = str(im.size[1]/96.0)+'in'
            return tf, mimetype, size_x, size_y
        return tf, mimetype

    def _large(self, attr, field, n):
        """ Return the n-th largest ``o.<field>`` value (n=1 is the maximum).

            Returns None when fewer than ``n`` values exist or ``n`` < 1.
        """
        return self._nth_ranked(self._asarray(attr, field), n, True)

    def _small(self, attr, field, n):
        """ Return the n-th smallest ``o.<field>`` value (n=1 is the minimum).

            Returns None when fewer than ``n`` values exist or ``n`` < 1.
        """
        return self._nth_ranked(self._asarray(attr, field), n, False)

    def _chunks(self, l, n):
        """ Yield successive n-sized chunks from l.
        """
        for start in range(0, len(l), n):
            yield l[start:start+n]

    def _browse(self, *args):
        """ Return a browse record.

            Accepts either one 'model,id' string or the two arguments
            (model, id).  Returns None when called without arguments or with
            an empty first argument.  Raises TypeError for more than two
            arguments.
        """
        if not args or not args[0]:
            return None
        if len(args) == 1:
            model, record_id = args[0].split(',')
            record_id = int(record_id)
        elif len(args) == 2:
            model, record_id = args
        else:
            raise TypeError("browse() expects 'model,id' or (model, id), "
                            "got %d arguments" % len(args))
        return self.pool.get(model).browse(self.cr, self.uid, record_id)

    def _get_safe(self, expression, obj):
        """ Evaluate ``expression`` with ``obj`` bound to ``o``.

            Returns None instead of raising when the evaluation fails.

            SECURITY: the expression is passed to eval(); it must come from
            trusted report templates only, never from end-user input.
        """
        try:
            return eval(expression, {'o': obj})
        except Exception:
            # "safe" means exactly this: any failure yields None.
            return None

    def debugit(self, object):
        """ Run the server from command line and
            call 'debugit' from the template to inspect variables.
        """
        import pdb;pdb.set_trace()
        return

