# This file is part of the Open Data Cube, see https://opendatacube.org for more information
#
# Copyright (c) 2015-2025 ODC Contributors
# SPDX-License-Identifier: Apache-2.0
""" Dask Distributed Tools
"""

import logging
import os
import queue
import threading
from random import randint
from typing import Any, Iterable, Optional, Tuple, Union

import dask
import toolz  # type: ignore[import]
from dask.distributed import Client

__all__ = (
"start_local_dask",
"pmap",
"compute_tasks",
"partition_map",
"save_blob_to_file",
"save_blob_to_s3",
)

_LOG = logging.getLogger(__name__)


def get_total_available_memory(check_jupyter_hub=True):
""" Figure out how much memory is available
1. Check MEM_LIMIT environment variable, set by jupyterhub
2. Use hardware information if that not set
"""
if check_jupyter_hub:
mem_limit = os.environ.get('MEM_LIMIT', None)
if mem_limit is not None:
return int(mem_limit)
from psutil import virtual_memory
return virtual_memory().total


def compute_memory_per_worker(n_workers: int = 1,
                              mem_safety_margin: Optional[Union[str, int]] = None,
                              memory_limit: Optional[Union[str, int]] = None) -> int:
    """ Figure out how much memory to assign per worker.

    The result can be passed to the ``memory_limit=`` parameter of a dask
    worker/cluster/client.
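
    For example (a sketch; values are illustrative, units follow dask's
    ``parse_bytes``):

    .. code-block:: python

        # 4 workers sharing 8GiB evenly: 2GiB each
        per_worker = compute_memory_per_worker(n_workers=4, memory_limit='8GiB')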
"""
from dask.utils import parse_bytes
if isinstance(memory_limit, str):
memory_limit = parse_bytes(memory_limit)
if isinstance(mem_safety_margin, str):
mem_safety_margin = parse_bytes(mem_safety_margin)
if memory_limit is None and mem_safety_margin is None:
total_bytes = get_total_available_memory()
        # leave 500MiB spare, or half of all memory if the machine has less than ~1GiB of RAM
mem_safety_margin = min(500*(1024*1024), total_bytes//2)
elif memory_limit is None:
total_bytes = get_total_available_memory()
elif mem_safety_margin is None:
total_bytes = memory_limit
mem_safety_margin = 0
else:
total_bytes = memory_limit
return (total_bytes - mem_safety_margin)//n_workers


def start_local_dask(n_workers: int = 1,
threads_per_worker: Optional[int] = None,
mem_safety_margin: Optional[Union[str, int]] = None,
memory_limit: Optional[Union[str, int]] = None,
**kw):
"""
Wrapper around ``distributed.Client(..)`` constructor that deals with memory better.
It also configures ``distributed.dashboard.link`` to go over proxy when operating
from behind jupyterhub.
:param n_workers: number of worker processes to launch
:param threads_per_worker: number of threads per worker, default is as many as there are CPUs
:param memory_limit: maximum memory to use across all workers
:param mem_safety_margin: bytes to reserve for the rest of the system, only applicable
if ``memory_limit=`` is not supplied.
.. note::
if ``memory_limit=`` is supplied, it will be parsed and divided equally between workers.
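
    A minimal usage sketch (all parameter values are illustrative):

    .. code-block:: python

        client = start_local_dask(n_workers=2,
                                  threads_per_worker=4,
                                  memory_limit='8GiB')
        # ... do work ...
        client.close()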
"""
    # If dashboard.link is set to the default value and we are running behind
    # JupyterHub, make the dashboard link go via the proxy
if dask.config.get("distributed.dashboard.link") == '{scheme}://{host}:{port}/status':
jup_prefix = os.environ.get('JUPYTERHUB_SERVICE_PREFIX')
if jup_prefix is not None:
jup_prefix = jup_prefix.rstrip('/')
dask.config.set({"distributed.dashboard.link": f"{jup_prefix}/proxy/{{port}}/status"})
memory_limit = compute_memory_per_worker(n_workers=n_workers,
memory_limit=memory_limit,
mem_safety_margin=mem_safety_margin)
client = Client(n_workers=n_workers,
threads_per_worker=threads_per_worker,
memory_limit=memory_limit,
**kw)
return client


def _randomize(prefix):
    return f'{prefix}-{randint(0, 0xFFFFFFFF):08x}'


def partition_map(n: int, func: Any, its: Iterable[Any],
name: str = 'compute') -> Iterable[Any]:
"""
Parallel map in lumps.
Partition sequence into lumps of size ``n``, then construct dask delayed computation evaluating to:
.. code-block:: python
[func(x) for x in its[0:1n]],
[func(x) for x in its[n:2n]],
...
[func(x) for x in its[..]],
This is useful when you need to process a large number of small (quick) tasks (pixel drill for example).
:param n: number of elements to process in one go
:param func: Function to apply (non-dask)
:param its: Values to feed to fun
:param name: How the computation should be named in dask visualizations
Returns
-------
Iterator of ``dask.Delayed`` objects.
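
    A usage sketch (the function and sizes are illustrative):

    .. code-block:: python

        lumps = partition_map(100, lambda x: x * 2, range(1000))
        results = dask.compute(*lumps)  # tuple of 10 lists, 100 results each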
"""
def lump_proc(dd):
return [func(d) for d in dd]
proc = dask.delayed(lump_proc, nout=1, pure=True)
data_name = _randomize('data_' + name)
name = _randomize(name)
for i, dd in enumerate(toolz.partition_all(n, its)):
lump = dask.delayed(dd,
pure=True,
traverse=False,
name=data_name + str(i))
yield proc(lump, dask_key_name=name + str(i))


def compute_tasks(tasks: Iterable[Any], client: Client,
max_in_flight: int = 3) -> Iterable[Any]:
""" Parallel compute stream with back pressure.
Equivalent to:
.. code-block:: python
(client.compute(task).result()
for task in tasks)
but with up to ``max_in_flight`` tasks being processed at the same time.
Input/Output order is preserved, so there is a possibility of head of
line blocking.
.. note::
lower limit is 3 concurrent tasks to simplify implementation,
there is no point calling this function if you want one active
task and supporting exactly 2 active tasks is not worth the complexity,
for now. We might special-case 2 at some point.
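
    A usage sketch (assumes ``client`` is a connected dask ``Client`` and
    ``tasks`` is an iterable of ``dask.Delayed`` objects):

    .. code-block:: python

        for result in compute_tasks(tasks, client, max_in_flight=10):
            process(result)  # ``process`` is a hypothetical consumer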
"""
# New thread:
# 1. Take dask task from iterator
# 2. Submit to client for processing
    # 3. Send it off to wrk_q
#
# Calling thread:
# 1. Pull scheduled future from wrk_q
# 2. Wait for result of the future
# 3. yield result to calling code
from .generic import it2q, qmap
# (max_in_flight - 2) -- one on each side of queue
wrk_q = queue.Queue(maxsize=max(1, max_in_flight - 2)) # type: queue.Queue
# fifo_timeout='0ms' ensures that priority of later tasks is lower
futures = (client.compute(task, fifo_timeout='0ms') for task in tasks)
in_thread = threading.Thread(target=it2q, args=(futures, wrk_q))
in_thread.start()
yield from qmap(lambda f: f.result(), wrk_q)
in_thread.join()


def pmap(func: Any,
its: Iterable[Any],
client: Client,
lump: int = 1,
max_in_flight: int = 3,
name: str = 'compute') -> Iterable[Any]:
""" Parallel map with back pressure.
Equivalent to this:
.. code-block:: python
(func(x) for x in its)
Except that ``func(x)`` runs concurrently on dask cluster.
:param func: Method that will be applied concurrently to data from ``its``
:param its: Iterator of input values
:param client: Connected dask client
:param lump: Group this many datasets into one task
:param max_in_flight: Maximum number of active tasks to submit
:param name: Dask name for computation
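
    A usage sketch (the function, lump size and in-flight limit are illustrative):

    .. code-block:: python

        client = start_local_dask()
        for y in pmap(lambda x: x * 2, range(1000), client,
                      lump=20, max_in_flight=100):
            print(y)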
"""
max_in_flight = max_in_flight // lump
tasks = partition_map(lump, func, its, name=name)
for xx in compute_tasks(tasks, client=client, max_in_flight=max_in_flight):
yield from xx


def _save_blob_to_file(data: Union[bytes, str],
fname: str,
with_deps=None) -> Tuple[str, bool]:
if isinstance(data, str):
data = data.encode('utf8')
try:
with open(fname, 'wb') as f:
f.write(data)
except IOError:
return (fname, False)
return (fname, True)


def _save_blob_to_s3(data: Union[bytes, str],
url: str,
profile: Optional[str] = None,
creds=None,
region_name: Optional[str] = None,
with_deps=None,
**kw) -> Tuple[str, bool]:
from botocore.errorfactory import ClientError
from botocore.exceptions import BotoCoreError
from .aws import s3_dump, s3_client
try:
s3 = s3_client(profile=profile,
creds=creds,
region_name=region_name,
cache=True)
result = s3_dump(data, url, s3=s3, **kw)
except (IOError, BotoCoreError, ClientError):
result = False
return url, result


_save_blob_to_file_delayed = dask.delayed(_save_blob_to_file, name='save-to-disk', pure=False)
_save_blob_to_s3_delayed = dask.delayed(_save_blob_to_s3, name='save-to-s3', pure=False)


def save_blob_to_file(data,
fname,
with_deps=None):
"""
Dump from memory to local filesystem as a dask delayed operation.
:param data: Data blob to save to file (have to fit into memory all at once),
strings will be saved in UTF8 format.
:param fname: Path to file
:param with_deps: Useful for introducing dependencies into dask graph,
for example save yaml file after saving all tiff files.
Returns
-------
``(FilePath, True)`` tuple on success
``(FilePath, False)`` on any error
.. note::
Dask workers must be local or have network filesystem mounted in
the same path as calling code.
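
    A usage sketch (the path is illustrative; assumes a connected ``client``):

    .. code-block:: python

        task = save_blob_to_file('hello', '/tmp/hello.txt')
        path, ok = client.compute(task).result()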
"""
return _save_blob_to_file_delayed(data, fname, with_deps=with_deps)


def save_blob_to_s3(data,
url,
profile=None,
creds=None,
region_name=None,
with_deps=None,
**kw):
"""
Dump from memory to S3 as a dask delayed operation.
:param data: Data blob to save to file (have to fit into memory all at once)
:param url: Url in a form s3://bucket/path/to/file
:param profile: Profile name to lookup (only used if session is not supplied)
:param creds: Override credentials with supplied data
:param region_name: Region name to use, overrides session setting
:param with_deps: Useful for introducing dependencies into dask graph,
for example save yaml file after saving all tiff files.
:param kw: Passed on to ``s3.put_object(..)``, useful for things like ContentType/ACL
Returns
-------
``(url, True)`` tuple on success
``(url, False)`` on any error
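
    A usage sketch (the bucket and key are hypothetical; assumes a connected ``client``):

    .. code-block:: python

        task = save_blob_to_s3(b'{}', 's3://some-bucket/empty.json',
                               ContentType='application/json')
        url, ok = client.compute(task).result()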
"""
return _save_blob_to_s3_delayed(data, url,
profile=profile,
creds=creds,
region_name=region_name,
with_deps=with_deps,
**kw)