y09_2047.py 4.26 KB
# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import sys
import gzip
import marshal
from math import log

from ..utils import frequency


class CharacterBasedGenerativeModel(object):
    """Trigram sequence tagger over ``(character, tag)`` pairs.

    The model keeps unigram, bigram and trigram counters and scores a
    transition as the log of a weighted sum of the three relative-frequency
    estimates (weights ``l1``, ``l2`` and ``l3``).  ``train`` fills the
    counters and derives the weights; ``tag`` searches for the best tag
    sequence for a piece of input.

    NOTE(review): the counters are ``frequency.NormalProb`` objects defined
    outside this file.  This class assumes ``get(key)[1]`` is the stored
    count for ``key`` (0 when unseen), ``freq(key)`` is its relative
    frequency, ``getsum()`` is the total count and ``samples()`` yields the
    stored keys -- confirm against ``utils.frequency``.
    """

    def __init__(self):
        # Interpolation weights for the unigram / bigram / trigram terms;
        # all zero until train() (or load()) sets them.
        self.l1 = 0.0
        self.l2 = 0.0
        self.l3 = 0.0
        # Tag inventory tried for every character in tag().
        # Presumably begin / middle / end / single -- TODO confirm.
        self.status = ('b', 'm', 'e', 's')
        self.uni = frequency.NormalProb()
        self.bi = frequency.NormalProb()
        self.tri = frequency.NormalProb()

    def save(self, fname, iszip=True):
        """Serialize the model's attributes to ``fname`` with ``marshal``.

        Attributes that carry their own ``__dict__`` (the counters) are
        stored as that dict; everything else is stored as-is.  On Python 3
        the suffix ``'.3'`` is appended to ``fname``.

        :param fname: target path (without the Python 3 suffix).
        :param iszip: when true (default) the payload is gzip-compressed.
        """
        d = {}
        for k, v in self.__dict__.items():
            d[k] = v.__dict__ if hasattr(v, '__dict__') else v
        if sys.version_info[0] == 3:
            fname = fname + '.3'
        opener = gzip.open if iszip else open
        # The context manager closes the handle even if the write fails.
        with opener(fname, 'wb') as f:
            f.write(marshal.dumps(d))

    def load(self, fname, iszip=True):
        """Restore attributes previously written by :meth:`save`.

        On Python 3 the suffix ``'.3'`` is appended to ``fname``.  With
        ``iszip`` true the file is read as gzip first; if that raises
        ``IOError`` the file is re-read as plain (uncompressed) marshal data.

        Every key in the file must already exist as an instance attribute
        (``KeyError`` otherwise).  Attributes that have a ``__dict__`` get
        that dict replaced; others are overwritten with the stored value.

        NOTE: ``marshal`` is not safe against erroneous or maliciously
        constructed data -- only load model files from a trusted source.

        :param fname: source path (without the Python 3 suffix).
        :param iszip: whether to try gzip decompression first.
        """
        if sys.version_info[0] == 3:
            fname = fname + '.3'
        if not iszip:
            with open(fname, 'rb') as f:
                d = marshal.load(f)
        else:
            try:
                with gzip.open(fname, 'rb') as f:
                    d = marshal.loads(f.read())
            except IOError:
                # Not a gzip stream: fall back to reading raw marshal bytes.
                with open(fname, 'rb') as f:
                    d = marshal.loads(f.read())
        for k, v in d.items():
            if hasattr(self.__dict__[k], '__dict__'):
                self.__dict__[k].__dict__ = v
            else:
                self.__dict__[k] = v

    def div(self, v1, v2):
        """Return ``v1 / v2`` as a float, or ``0`` when ``v2`` is zero."""
        if v2 == 0:
            return 0
        return float(v1)/v2

    def train(self, data):
        """Accumulate n-gram counts from ``data`` and set ``l1``/``l2``/``l3``.

        :param data: iterable of sentences, each an iterable of
            ``(word, tag)`` pairs.

        Each sentence is prefixed with two ``('', 'BOS')`` markers so that
        the first real item already has a full trigram context.  After
        counting, every stored trigram votes (weighted by its own count) for
        whichever of the three leave-one-out estimates is largest; the
        normalized vote totals become the interpolation weights.
        """
        for sentence in data:
            # Sliding window holding the last (up to) three items.
            now = [('', 'BOS'), ('', 'BOS')]
            self.bi.add((('', 'BOS'), ('', 'BOS')), 1)
            self.uni.add(('', 'BOS'), 2)
            for word, tag in sentence:
                now.append((word, tag))
                self.uni.add((word, tag), 1)
                self.bi.add(tuple(now[1:]), 1)
                self.tri.add(tuple(now), 1)
                now.pop(0)
        tl1 = 0.0
        tl2 = 0.0
        tl3 = 0.0
        samples = sorted(self.tri.samples(), key=lambda x: self.tri.get(x)[1])
        for gram in samples:
            count = self.tri.get(gram)[1]
            # Each estimate subtracts one from numerator and denominator
            # (the occurrence being examined is left out); div() maps a
            # zero denominator to 0.
            c3 = self.div(count-1, self.bi.get(gram[:2])[1]-1)
            c2 = self.div(self.bi.get(gram[1:])[1]-1,
                          self.uni.get(gram[1])[1]-1)
            c1 = self.div(self.uni.get(gram[2])[1]-1, self.uni.getsum()-1)
            # Ties are resolved in favour of the higher-order estimate.
            if c3 >= c1 and c3 >= c2:
                tl3 += count
            elif c2 >= c1 and c2 >= c3:
                tl2 += count
            elif c1 >= c2 and c1 >= c3:
                tl1 += count
        total = tl1+tl2+tl3
        self.l1 = self.div(tl1, total)
        self.l2 = self.div(tl2, total)
        self.l3 = self.div(tl3, total)

    def log_prob(self, s1, s2, s3):
        """Return the log of the interpolated score of ``s3`` after ``s1, s2``.

        The score is ``l1 * freq(s3)`` plus the ``l2``-weighted bigram ratio
        plus the ``l3``-weighted trigram ratio.  Returns ``float('-inf')``
        when the combined score is zero.
        """
        uni = self.l1*self.uni.freq(s3)
        bi = self.div(self.l2*self.bi.get((s2, s3))[1], self.uni.get(s2)[1])
        tri = self.div(self.l3*self.tri.get((s1, s2, s3))[1],
                       self.bi.get((s1, s2))[1])
        if uni+bi+tri == 0:
            return float('-inf')
        return log(uni+bi+tri)

    def tag(self, data):
        """Assign one tag from ``self.status`` to every item of ``data``.

        A dynamic-programming search keeps, for each distinct pair of the
        last two ``(item, tag)`` states, one ``(score, tag path)`` entry.

        :param data: sequence of items to tag; it is iterated once here and
            passed to ``zip`` afterwards, so it should be re-iterable.
        :returns: ``zip(data, tags)`` for the highest-scoring path (a list
            on Python 2, a lazy iterator on Python 3).
        """
        # Each entry: ((previous state, current state), score, tags so far).
        now = [((('', 'BOS'), ('', 'BOS')), 0.0, [])]
        for w in data:
            stage = {}
            seen = any(self.uni.freq((w, s)) != 0 for s in self.status)
            if not seen:
                # Item never observed with any tag: extend every path with
                # every tag and leave the scores untouched.
                # NOTE(review): entries sharing a key overwrite each other
                # without comparing scores, so the survivor depends on the
                # iteration order of `now` -- confirm this is intended.
                for s in self.status:
                    for pre in now:
                        stage[(pre[0][1], (w, s))] = (pre[1], pre[2]+[s])
            else:
                for s in self.status:
                    for pre in now:
                        key = (pre[0][1], (w, s))
                        p = pre[1]+self.log_prob(pre[0][0], pre[0][1], (w, s))
                        # Keep only the best-scoring path per state pair.
                        if key not in stage or p > stage[key][0]:
                            stage[key] = (p, pre[2]+[s])
            now = [(key, val[0], val[1]) for key, val in stage.items()]
        return zip(data, max(now, key=lambda x: x[1])[2])