Source code for embfile.formats.bin

#: Names exported by ``from embfile.formats.bin import *``
__all__ = ['BinaryEmbFile', 'BinaryEmbFileReader']

import io
import mmap
from pathlib import Path
from typing import Any, BinaryIO, Dict, Iterable, Optional, Tuple

import numpy
from overrides import overrides

from embfile._utils import noop, progbar
from embfile.compression import open_file
from embfile.core import AbstractEmbFileReader, EmbFile
from embfile.core._file import (DEFAULT_VERBOSE, check_vector_size, glance_first_element,
                                warn_if_wrong_vocab_size)
from embfile.errors import BadEmbFile
from embfile.types import DType, PairsType, PathType, VectorType

#: Default text encoding; it is the default value of the ``encoding`` parameters in this module
DEFAULT_ENCODING = 'utf-8'

#: Default vector data type (little-endian single-precision floating point numbers, 4 bytes each)
DEFAULT_DTYPE = numpy.dtype('<f4')


def _bom_free_version(encoding: str) -> str:
    """ Given an utf encoding, returns a BOM-free version of it (little-endian version) """

    def utf_aliases(num):
        return {fmt % num for fmt in ['u%d', 'utf%d', 'utf-%d', 'utf_%d']}

    encoding = encoding.lower()
    if encoding in utf_aliases(16):
        return 'utf_16_le'
    if encoding in utf_aliases(32):
        return 'utf_32_le'
    return encoding


def _take_until_delimiter(delim: str,
                          byte_array: bytes,
                          start_position: int,
                          encoding: str,
                          max_bytes: Optional[int] = None) -> Tuple[str, int]:
    """
    Reads the text that precedes a delimiter character and returns it along with the index of the
    first byte after the delimiter.
    Used by _read_until_delimiter, factored out for easier testing (this can be tested on bytes
    array, no need for creating a mmap object).

    Implementation Notes
    --------------------
    Since the encoding could be a multy-byte encoding (e.g. utf-16), we can't simply
    perform a search of the encoded delimiter at the binary level (searching bytes in bytes),
    we have to consider "character boundaries".
    Indeed, if the delimiter is encoded as 'x' and there's one character in the stream
    encoded as 'yx' before the actual delimiter, a simple find('x') would return a
    "false positive".

    Unfortunately, reading and (incrementally) decoding one byte at a time is incredible
    slow in Python. Decoding in chunk is better, but still really slow.

    For this reasons and because the bug described above is very rare (or impossible
    for 1-byte encodings), I first search the delimiter in the "byte space" ignoring
    character boundaries;
    then I check whether decoding the bytes that precede the delimiter produces a
    UnicodeDecodeError due to "truncated data". If it does, we know we found a false
    positives; in that case:

    1. we find the end of the truncated character which starts at ``decode_error.start``
    2. we continue our search of the delimiter bytes after the character end.
    """
    delim_bytes = delim.encode(encoding)

    find_start = start_position
    if max_bytes is not None:
        find_end = start_position + max_bytes

    while True:
        if byte_array[find_start:find_start + 1] == b'':
            return '', find_start

        if max_bytes is None:
            delim_start = byte_array.find(delim_bytes, find_start)
        else:
            delim_start = byte_array.find(delim_bytes, find_start, find_end)

        if delim_start < 0:
            if max_bytes:
                msg = ("expected delimiter %r wasn't found from position %d after %d bytes read "
                       "(max number of bytes allowed)." % (delim, start_position, max_bytes))
            else:
                msg = ("expected delimiter %r wasn't found from position %d to the end of the file."
                       % (delim, start_position))

            raise BadEmbFile(msg)

        delim_end = delim_start + len(delim_bytes)
        try:
            text = byte_array[start_position:delim_start].decode(encoding)
            return text, delim_end
        except UnicodeDecodeError as err:
            # adjust limits of the error to make it relative to byte_array start
            err.start += start_position
            err.end += start_position

            if err.end == delim_start:  # err.reason == "truncated data"
                # False positive: delimiter's bytes starts inside another character's bytes
                # Find the end of the truncated character
                char_start = err.start
                for char_end in range(char_start + 1, char_start + 4):
                    try:
                        byte_array[char_start:char_end].decode(encoding)
                    except UnicodeDecodeError as char_err:
                        char_err.start += char_start
                        char_err.end += char_start

                        if char_err.end != char_end:
                            char_err.object = byte_array
                            raise char_err
                    else:
                        # continue the search of the delimiter after the character
                        find_start = char_end
                        break
            else:
                # the error has nothing to do with our imperfect search of the delimiter
                err.object = byte_array
                raise err


def _read_until_delimiter(delim: str,
                          mem_map: mmap.mmap,
                          encoding: str,
                          max_bytes: Optional[int] = None) -> str:
    """
    Consume and return the text found between the current position of ``mem_map`` and the next
    occurrence of ``delim``; the map is left positioned on the first byte after the delimiter.
    When there is nothing left to read, the empty string '' is returned.

    Raises:
        UnicodeDecodeError
        BadEmbFile:
            raised when the delimiter can't be found before the end of the file (or within
            ``max_bytes`` bytes, when that limit is provided).
    """
    current_position = mem_map.tell()
    result = _take_until_delimiter(delim, mem_map, current_position,  # type: ignore
                                   encoding, max_bytes)
    decoded_text, position_after_delim = result
    mem_map.seek(position_after_delim, io.SEEK_SET)
    return decoded_text


class BinaryEmbFileReader(AbstractEmbFileReader):
    """ :class:`~embfile.core.EmbFileReader` for the binary format. """

    #: Conservative upper bound for the length (in bytes) of the header of a binary embedding file
    _MAX_HEADER_BYTES = 128

    #: Conservative upper bound for the length (in bytes) of a word
    _MAX_WORD_BYTES = 1024

    def __init__(self, file_obj: BinaryIO,
                 encoding: str = DEFAULT_ENCODING,
                 dtype: DType = DEFAULT_DTYPE,
                 out_dtype: Optional[DType] = None):
        """
        Memory-maps ``file_obj`` (read-only) and reads the header.

        Args:
            file_obj: binary file object; it must expose a real file descriptor (``fileno()``)
            encoding: text encoding of the header and of the words
            dtype: data type of the vectors stored in the file
            out_dtype: data type of the vectors returned; defaults to ``dtype``
        """
        # Compare with None explicitly: a numpy.dtype instance may be falsy (len() of a
        # non-structured dtype is 0), so ``out_dtype or dtype`` could discard an explicit value.
        super().__init__(out_dtype if out_dtype is not None else dtype)
        self.dtype = numpy.dtype(dtype)
        encoding = _bom_free_version(encoding)
        self.encoding = encoding
        self._file_obj = file_obj
        self._mmap = mmap.mmap(file_obj.fileno(), 0, access=mmap.ACCESS_READ)
        try:
            self.header = self._read_header()
        except Exception:
            # Don't leave the mapping open if the header is missing or malformed
            self._mmap.close()
            raise
        self._body_start = self._mmap.tell()  # store the position where the actual data starts
        self._vector_size_in_bytes = self.dtype.itemsize * self.header['vector_size']

    @classmethod
    def from_path(cls, path: PathType,
                  encoding: str = DEFAULT_ENCODING,
                  dtype: DType = DEFAULT_DTYPE,
                  out_dtype: Optional[DType] = None):
        """ Opens the file at ``path`` in binary mode and returns a reader for it. """
        file_obj = open_file(path, 'rb')
        try:
            return cls(file_obj, encoding=encoding, dtype=dtype, out_dtype=out_dtype)
        except Exception:
            # The file was opened here, so it must be closed here if the reader can't be built
            file_obj.close()
            raise

    def _read_header(self) -> Dict[str, Any]:
        """ Reads the first line and parses it as ``<vocab_size> <vector_size>``. """
        header_line = _read_until_delimiter('\n', self._mmap, self.encoding,
                                            self._MAX_HEADER_BYTES)
        vocab_size, vector_size = map(int, header_line.split())
        return {'vocab_size': vocab_size, 'vector_size': vector_size}

    @overrides
    def _close(self) -> None:
        self._mmap.close()
        self._file_obj.close()

    @overrides
    def _reset(self) -> None:
        # Go back to the first byte after the header
        self._mmap.seek(self._body_start, io.SEEK_SET)

    @overrides
    def _read_word(self) -> str:
        word = _read_until_delimiter(' ', self._mmap, self.encoding, self._MAX_WORD_BYTES)
        if not word:
            # An empty string is returned when there are no more bytes to read
            raise StopIteration
        return word

    @overrides
    def _read_vector(self) -> VectorType:
        vec_bytes = self._mmap.read(self._vector_size_in_bytes)
        if len(vec_bytes) != self._vector_size_in_bytes:
            # mmap.read returns fewer bytes when the end of the file is reached
            raise BadEmbFile('unexpected end of file: expected %d bytes for a vector, found %d'
                             % (self._vector_size_in_bytes, len(vec_bytes)))
        vector = numpy.frombuffer(vec_bytes, dtype=self.dtype)
        return numpy.asarray(vector, dtype=self.out_dtype)

    @overrides
    def _skip_vector(self) -> None:
        self._mmap.seek(self._vector_size_in_bytes, io.SEEK_CUR)
class BinaryEmbFile(EmbFile):
    """
    Format used by the Google word2vec tool. You can use it to read the file
    `GoogleNews-vectors-negative300.bin
    <https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing>`_.

    It begins with a text header line of space-separated fields::

        <vocab_size> <vector_size>\\n

    Each word vector pair is encoded as following:

    - encoded word + space
    - followed by the binary representation of the vector.

    Attributes:
        path
        encoding
        dtype
        out_dtype
        verbose
    """
    DEFAULT_EXTENSION = '.bin'

    def __init__(self, path: PathType,
                 encoding: str = DEFAULT_ENCODING,
                 dtype: DType = DEFAULT_DTYPE,
                 out_dtype: Optional[DType] = None,
                 verbose: int = DEFAULT_VERBOSE):
        """
        Args:
            path:
                path to the (possibly compressed) file
            encoding:
                text encoding; **note:** if you provide an utf encoding (e.g. *utf-16*) that uses
                a BOM (Byte Order Mark) without specifying the byte-endianness (e.g. *utf-16-le*
                or *utf-16-be*), the little-endian version is used (*utf-16-le*).
            dtype:
                a valid numpy data type (or whatever you can pass to numpy.dtype())
                (default: '<f4'; little-endian float, 4 bytes)
            out_dtype:
                all the vectors returned will be (if needed) converted to this data type;
                by default, it is equal to the original data type of the vectors in the file,
                i.e. no conversion takes place.
        """
        super().__init__(path, out_dtype, verbose=verbose)
        self.encoding = _bom_free_version(encoding)
        self.dtype = numpy.dtype(dtype)
        # Compare with None explicitly: a numpy.dtype instance may be falsy (len() of a
        # non-structured dtype is 0), which would silently discard an explicit out_dtype.
        self.out_dtype = numpy.dtype(out_dtype) if out_dtype is not None else self.dtype

        # Read the header
        with self.reader() as reader:
            self.vocab_size = reader.header['vocab_size']
            self.vector_size = reader.header['vector_size']

    @overrides
    def _reader(self) -> BinaryEmbFileReader:
        return BinaryEmbFileReader.from_path(
            path=self.path,
            encoding=self.encoding,
            dtype=self.dtype,
            out_dtype=self.out_dtype
        )

    @overrides
    def _close(self):
        pass

    @classmethod
    def _create(cls, out_path: Path,
                word_vectors: Iterable[Tuple[str, VectorType]],
                vector_size: int,
                vocab_size: Optional[int],
                compression: Optional[str] = None,
                verbose: bool = True,
                encoding: str = DEFAULT_ENCODING,
                dtype: Optional[DType] = None) -> Path:
        """
        Writes the header and all the (word, vector) pairs to ``out_path``.

        If ``dtype`` is not given, the data type of the first vector is used for all the
        vectors. Returns ``out_path``.

        Raises:
            ValueError: if ``vocab_size`` is missing or if a word contains spaces
        """
        echo = print if verbose else noop
        encoding = _bom_free_version(encoding)

        # ``dtype is None`` rather than ``not dtype``: a numpy.dtype instance may be falsy and
        # an explicitly requested data type must not be replaced by the inferred one.
        if dtype is None:
            (_, first_vector), word_vectors = glance_first_element(word_vectors)
            dtype = first_vector.dtype
        else:
            dtype = numpy.dtype(dtype)

        if not vocab_size:
            raise ValueError('unable to infer vocab_size; you must manually provide it')

        # Tracked explicitly so that it is defined even when word_vectors yields nothing
        num_written = 0

        with open_file(out_path, 'wb', compression=compression) as file:
            header_line = '%d %d\n' % (vocab_size, vector_size)
            echo('Writing the header: %s' % header_line.rstrip('\n'))
            header_bytes = header_line.encode(encoding)
            file.write(header_bytes)

            for i, (word, vector) in progbar(enumerate(word_vectors), verbose, total=vocab_size):
                if ' ' in word:
                    raise ValueError("the word number %d contains one or more spaces: %r"
                                     % (i, word))
                file.write((word + ' ').encode(encoding))
                check_vector_size(i, vector, vector_size)
                file.write(numpy.asarray(vector, dtype).tobytes())
                num_written = i + 1

        warn_if_wrong_vocab_size(vocab_size, actual_size=num_written,
                                 extra_info='As a consequence, the header of the file has a wrong '
                                            'vocab_size')
        return out_path

    @classmethod
    def create(cls, out_path: PathType,
               word_vectors: PairsType,
               vocab_size: Optional[int] = None,
               compression: Optional[str] = None,
               verbose: bool = True,
               overwrite: bool = False,
               encoding: str = DEFAULT_ENCODING,
               dtype: Optional[DType] = None) -> None:
        """
        Format-specific arguments are ``encoding`` and ``dtype``.

        **Note:** all the text is encoded without BOM (Byte Order Mark). If you pass "utf-16" or
        "utf-32", the little-endian version is used (e.g. "utf-16-le")

        See :meth:`~embfile.core.file.EmbFile.create` for more.
        """
        super().create(out_path, word_vectors, vocab_size, compression, verbose, overwrite,
                       encoding=encoding, dtype=dtype)

    @classmethod
    def create_from_file(cls, source_file: 'EmbFile',
                         out_dir: Optional[PathType] = None,
                         out_filename: Optional[str] = None,
                         vocab_size: Optional[int] = None,
                         compression: Optional[str] = None,
                         verbose: bool = True,
                         overwrite: bool = False,
                         encoding: str = DEFAULT_ENCODING,
                         dtype: Optional[DType] = None) -> Path:
        """
        Format-specific arguments are ``encoding`` and ``dtype``.

        **Note:** all the text is encoded without BOM (Byte Order Mark). If you pass "utf-16" or
        "utf-32", the little-endian version is used (e.g. "utf-16-le")

        See :meth:`~embfile.core.file.EmbFile.create_from_file` for more.
        """
        return super().create_from_file(
            source_file, out_dir, out_filename, vocab_size, compression, verbose, overwrite,
            encoding=encoding, dtype=dtype)