Source code for docarray.array.mixins.io.binary

import base64
import io
import os
import os.path
import pickle
from contextlib import nullcontext
from pathlib import Path
from typing import Union, BinaryIO, TYPE_CHECKING, Type, Optional, Generator

from docarray.helper import (
    get_compress_ctx,
    decompress_bytes,
    protocol_and_compress_from_file_path,
)

if TYPE_CHECKING:
    from docarray.typing import T
    from docarray.proto.docarray_pb2 import DocumentArrayProto
    from docarray import Document, DocumentArray


[docs]class LazyRequestReader: def __init__(self, r): self._data = r.iter_content(chunk_size=1024 * 1024) self.content = b'' def __getitem__(self, item: slice): while len(self.content) < item.stop: try: self.content += next(self._data) except StopIteration: return self.content[item.start : -1 : item.step] return self.content[item]
[docs]class BinaryIOMixin: """Save/load an array to a binary file."""
[docs] @classmethod def load_binary( cls: Type['T'], file: Union[str, BinaryIO, bytes, Path], protocol: str = 'pickle-array', compress: Optional[str] = None, _show_progress: bool = False, streaming: bool = False, *args, **kwargs, ) -> Union['DocumentArray', Generator['Document', None, None]]: """Load array elements from a compressed binary file. :param file: File or filename or serialized bytes where the data is stored. :param protocol: protocol to use :param compress: compress algorithm to use :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` :param streaming: if `True` returns a generator over `Document` objects. In case protocol is pickle the `Documents` are streamed from disk to save memory usage :return: a DocumentArray object .. note:: If `file` is `str` it can specify `protocol` and `compress` as file extensions. This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a string interpolation of the respective `protocol` and `compress` methods. For example if `file=my_docarray.protobuf.lz4` then the binary data will be loaded assuming `protocol=protobuf` and `compress=lz4`. """ if isinstance(file, (io.BufferedReader, LazyRequestReader)): file_ctx = nullcontext(file) elif isinstance(file, bytes): file_ctx = nullcontext(file) # by checking path existence we allow file to be of type Path, LocalPath, PurePath and str elif os.path.exists(file): protocol, compress = protocol_and_compress_from_file_path( file, protocol, compress ) file_ctx = open(file, 'rb') else: raise FileNotFoundError(f'cannot find file {file}') if streaming: return cls._load_binary_stream( file_ctx, protocol=protocol, compress=compress, _show_progress=_show_progress, ) else: return cls._load_binary_all( file_ctx, protocol, compress, _show_progress, *args, **kwargs )
@classmethod def _load_binary_stream( cls: Type['T'], file_ctx: str, protocol=None, compress=None, _show_progress=False, ) -> Generator['Document', None, None]: """Yield `Document` objects from a binary file :param protocol: protocol to use :param compress: compress algorithm to use :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` :return: a generator of `Document` objects """ from docarray import Document from docarray.array.mixins.io.pbar import get_progressbar from rich import filesize with file_ctx as f: version_numdocs_lendoc0 = f.read(9) # 1 byte (uint8) version = int.from_bytes(version_numdocs_lendoc0[0:1], 'big', signed=False) # 8 bytes (uint64) num_docs = int.from_bytes(version_numdocs_lendoc0[1:9], 'big', signed=False) pbar, t = get_progressbar( 'Deserializing', disable=not _show_progress, total=num_docs ) with pbar: _total_size = 0 pbar.start_task(t) for _ in range(num_docs): # 4 bytes (uint32) len_current_doc_in_bytes = int.from_bytes( f.read(4), 'big', signed=False ) _total_size += len_current_doc_in_bytes yield Document.from_bytes( f.read(len_current_doc_in_bytes), protocol=protocol, compress=compress, ) pbar.update( t, advance=1, total_size=str(filesize.decimal(_total_size)) ) @classmethod def _load_binary_all( cls, file_ctx, protocol, compress, show_progress, *args, **kwargs ): """Read a `DocumentArray` object from a binary file :param protocol: protocol to use :param compress: compress algorithm to use :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` :return: a `DocumentArray` """ from docarray import Document with file_ctx as fp: d = fp.read() if hasattr(fp, 'read') else fp if protocol == 'pickle-array' or protocol == 'protobuf-array': if get_compress_ctx(algorithm=compress) is not None: d = decompress_bytes(d, algorithm=compress) compress = None if protocol == 'protobuf-array': from docarray.proto.docarray_pb2 import DocumentArrayProto dap = DocumentArrayProto() dap.ParseFromString(d) return cls.from_protobuf(dap) elif protocol == 'pickle-array': return pickle.loads(d) # Binary format for streaming case else: from rich import filesize from docarray.array.mixins.io.pbar import get_progressbar # 1 byte (uint8) version = int.from_bytes(d[0:1], 'big', signed=False) # 8 bytes (uint64) num_docs = int.from_bytes(d[1:9], 'big', signed=False) pbar, t = get_progressbar( 'Deserializing', disable=not show_progress, total=num_docs ) # this 9 is version + num_docs bytes used start_pos = 9 docs = [] with pbar: _total_size = 0 pbar.start_task(t) for _ in range(num_docs): # 4 bytes (uint32) len_current_doc_in_bytes = int.from_bytes( d[start_pos : start_pos + 4], 'big', signed=False ) start_doc_pos = start_pos + 4 end_doc_pos = start_doc_pos + len_current_doc_in_bytes start_pos = end_doc_pos # variable length bytes doc doc = Document.from_bytes( d[start_doc_pos:end_doc_pos], protocol=protocol, compress=compress, ) docs.append(doc) _total_size += len_current_doc_in_bytes pbar.update( t, advance=1, total_size=str(filesize.decimal(_total_size)) ) return cls(docs, *args, **kwargs)
[docs] @classmethod def from_bytes( cls: Type['T'], data: bytes, protocol: str = 'pickle-array', compress: Optional[str] = None, _show_progress: bool = False, *args, **kwargs, ) -> 'T': return cls.load_binary( data, protocol=protocol, compress=compress, _show_progress=_show_progress, *args, **kwargs, )
[docs] def save_binary( self, file: Union[str, BinaryIO], protocol: str = 'pickle-array', compress: Optional[str] = None, ) -> None: """Save array elements into a binary file. :param file: File or filename to which the data is saved. :param protocol: protocol to use :param compress: compress algorithm to use .. note:: If `file` is `str` it can specify `protocol` and `compress` as file extensions. This functionality assumes `file=file_name.$protocol.$compress` where `$protocol` and `$compress` refer to a string interpolation of the respective `protocol` and `compress` methods. For example if `file=my_docarray.protobuf.lz4` then the binary data will be created using `protocol=protobuf` and `compress=lz4`. Comparing to :meth:`save_json`, it is faster and the file is smaller, but not human-readable. .. note:: To get a binary presentation in memory, use ``bytes(...)``. """ if isinstance(file, io.BufferedWriter): file_ctx = nullcontext(file) else: _protocol, _compress = protocol_and_compress_from_file_path(file) if _protocol is not None: protocol = _protocol if _compress is not None: compress = _compress file_ctx = open(file, 'wb') self.to_bytes(protocol=protocol, compress=compress, _file_ctx=file_ctx)
[docs] def to_bytes( self, protocol: str = 'pickle-array', compress: Optional[str] = None, _file_ctx: Optional[BinaryIO] = None, _show_progress: bool = False, ) -> bytes: """Serialize itself into bytes. For more Pythonic code, please use ``bytes(...)``. :param _file_ctx: File or filename or serialized bytes where the data is stored. :param protocol: protocol to use :param compress: compress algorithm to use :param _show_progress: show progress bar, only works when protocol is `pickle` or `protobuf` :return: the binary serialization in bytes """ if protocol == 'protobuf-array' or protocol == 'pickle-array': compress_ctx = get_compress_ctx(compress, mode='wb') else: # delegate the compression to per-doc compression compress_ctx = None with (_file_ctx or io.BytesIO()) as bf: if compress_ctx is None: # if compress do not support streaming then postpone the compress # into the for-loop f, fc = bf, nullcontext() else: f = compress_ctx(bf) fc = f compress = None with fc: if protocol == 'protobuf-array': f.write(self.to_protobuf().SerializePartialToString()) elif protocol == 'pickle-array': f.write(pickle.dumps(self)) elif protocol in ('pickle', 'protobuf'): from rich import filesize from docarray.array.mixins.io.pbar import get_progressbar pbar, t = get_progressbar( 'Serializing', disable=not _show_progress, total=len(self) ) f.write(self._stream_header) with pbar: _total_size = 0 pbar.start_task(t) for d in self: r = d._to_stream_bytes(protocol=protocol, compress=compress) f.write(r) _total_size += len(r) pbar.update( t, advance=1, total_size=str(filesize.decimal(_total_size)), ) else: raise ValueError( f'protocol={protocol} is not supported. Can be only `protobuf`, `pickle`, `protobuf-array`, `pickle-array`.' ) if not _file_ctx: return bf.getvalue()
[docs] def to_protobuf(self, ndarray_type: Optional[str] = None) -> 'DocumentArrayProto': """Convert DocumentArray into a Protobuf message. :param ndarray_type: can be ``list`` or ``numpy``, if set it will force all ndarray-like object from all Documents to ``List`` or ``numpy.ndarray``. :return: the protobuf message """ from docarray.proto.docarray_pb2 import DocumentArrayProto dap = DocumentArrayProto() for d in self: dap.docs.append(d.to_protobuf(ndarray_type)) return dap
[docs] @classmethod def from_protobuf(cls: Type['T'], pb_msg: 'DocumentArrayProto') -> 'T': from docarray import Document return cls(Document.from_protobuf(od) for od in pb_msg.docs)
def __bytes__(self): return self.to_bytes()
[docs] @classmethod def from_base64( cls: Type['T'], data: str, protocol: str = 'pickle-array', compress: Optional[str] = None, _show_progress: bool = False, *args, **kwargs, ) -> 'T': return cls.load_binary( base64.b64decode(data), protocol=protocol, compress=compress, _show_progress=_show_progress, *args, **kwargs, )
[docs] def to_base64( self, protocol: str = 'pickle-array', compress: Optional[str] = None, _show_progress: bool = False, ) -> str: return base64.b64encode(self.to_bytes(protocol, compress)).decode('utf-8')
@property def _stream_header(self) -> bytes: # Binary format for streaming case # V1 DocArray streaming serialization format # | 1 byte | 8 bytes | 4 bytes | variable | 4 bytes | variable ... # 1 byte (uint8) version_byte = b'\x01' # 8 bytes (uint64) num_docs_as_bytes = len(self).to_bytes(8, 'big', signed=False) return version_byte + num_docs_as_bytes