Source code for docarray.array.mixins.io.pushpull

import json
import os
import warnings
from functools import lru_cache
from pathlib import Path
from typing import Dict, Type, TYPE_CHECKING, Optional
from urllib.request import Request, urlopen

from ....helper import get_request_header

if TYPE_CHECKING:
    from ....typing import T

JINA_CLOUD_CONFIG = 'config.json'


@lru_cache()
def _get_hub_config() -> Optional[Dict]:
    hub_root = Path(os.environ.get('JINA_HUB_ROOT', Path.home().joinpath('.jina')))

    if not hub_root.exists():
        hub_root.mkdir(parents=True, exist_ok=True)

    config_file = hub_root.joinpath(JINA_CLOUD_CONFIG)
    if config_file.exists():
        with open(config_file) as f:
            return json.load(f)


@lru_cache()
def _get_cloud_api() -> str:
    """Get Cloud Api for transmitting data to the cloud.

    :return: Cloud Api Url
    """
    return os.environ.get('JINA_HUBBLE_REGISTRY', 'https://api.hubble.jina.ai')


[docs]class PushPullMixin: """Transmitting :class:`DocumentArray` via Jina Cloud Service""" _max_bytes = 4 * 1024 * 1024 * 1024
[docs] def push(self, name: str, show_progress: bool = False, public: bool = True) -> Dict: """Push this DocumentArray object to Jina Cloud which can be later retrieved via :meth:`.push` .. note:: - Push with the same ``name`` will override the existing content. - Kinda like a public clipboard where everyone can override anyone's content. So to make your content survive longer, you may want to use longer & more complicated name. - The lifetime of the content is not promised atm, could be a day, could be a week. Do not use it for persistence. Only use this full temporary transmission/storage/clipboard. :param name: a name that later can be used for retrieve this :class:`DocumentArray`. :param show_progress: if to show a progress bar on pulling :param public: If True, the DocumentArray will be shared publicly. Otherwise, it will be private. """ import requests delimiter = os.urandom(32) (data, ctype) = requests.packages.urllib3.filepost.encode_multipart_formdata( { 'file': ( 'DocumentArray', delimiter, ), 'name': name, 'type': 'documentArray', 'public': public, } ) headers = {'Content-Type': ctype, **get_request_header()} _hub_config = _get_hub_config() if _hub_config: auth_token = _hub_config.get('auth_token') headers['Authorization'] = f'token {auth_token}' _head, _tail = data.split(delimiter) _head += self._stream_header from rich import filesize from .pbar import get_progressbar pbar, t = get_progressbar('Pushing', disable=not show_progress, total=len(self)) def gen(): total_size = 0 pbar.start_task(t) yield _head def _get_chunk(_batch): return b''.join( d._to_stream_bytes(protocol='protobuf', compress='gzip') for d in _batch ), len(_batch) for chunk, num_doc_in_chunk in self.map_batch(_get_chunk, batch_size=32): total_size += len(chunk) if total_size > self._max_bytes: warnings.warn( f'DocumentArray is too big. The pushed DocumentArray might be chopped off.' ) break yield chunk pbar.update( t, advance=num_doc_in_chunk, total_size=str(filesize.decimal(total_size)), ) yield _tail with pbar: response = requests.post( f'{_get_cloud_api()}/v2/rpc/artifact.upload', data=gen(), headers=headers, ) if response.ok: return response.json()['data'] else: response.raise_for_status()
[docs] @classmethod def pull( cls: Type['T'], name: str, show_progress: bool = False, local_cache: bool = False, *args, **kwargs, ) -> 'T': """Pulling a :class:`DocumentArray` from Jina Cloud Service to local. :param name: the upload name set during :meth:`.push` :param show_progress: if to show a progress bar on pulling :param local_cache: store the downloaded DocumentArray to local folder :return: a :class:`DocumentArray` object """ import requests headers = {} _hub_config = _get_hub_config() if _hub_config: auth_token = _hub_config.get('auth_token') headers['Authorization'] = f'token {auth_token}' url = f'{_get_cloud_api()}/v2/rpc/artifact.getDownloadUrl?name={name}' response = requests.get(url, headers=headers) if response.ok: url = response.json()['data']['download'] else: response.raise_for_status() with requests.get( url, stream=True, headers=get_request_header(), ) as r: r.raise_for_status() _da_len = int(r.headers['Content-length']) from .binary import LazyRequestReader _source = LazyRequestReader(r) if local_cache and os.path.exists(f'.cache/{name}'): _cache_len = os.path.getsize(f'.cache/{name}') if _cache_len == _da_len: _source = f'.cache/{name}' r = cls.load_binary( _source, protocol='protobuf', compress='gzip', _show_progress=show_progress, *args, **kwargs, ) if isinstance(_source, LazyRequestReader) and local_cache: os.makedirs('.cache', exist_ok=True) with open(f'.cache/{name}', 'wb') as fp: fp.write(_source.content) return r