"""
Stream — the composable pipeline builder.
``Stream`` is the central object. It holds a lazy *recipe* of processes
that are only materialized (channels + tasks spawned) when a terminal
operation is called (``collect``, ``count``, ``reduce``, or via the
``open()`` context manager).
Two composition styles are supported:
1. **Method chaining** (builder pattern)::
result = await (
Stream.from_iterable(items)
.map(process, workers=4)
.filter(is_ok)
.collect()
)
2. **Pipe operator** (``|``)::
from anyiostream import pipe
result = await (
Stream.from_iterable(items)
| pipe.map(process, workers=4)
| pipe.filter(is_ok)
| pipe.collect()
)
Design: The pipeline is executed inside a single ``async with
create_task_group()`` block so that structured concurrency is
preserved — every spawned task is guaranteed to complete (or be
cancelled) before the block exits.
"""
from __future__ import annotations
import math
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable, Iterable
from contextlib import asynccontextmanager
from typing import (
Any,
TypeVar,
)
import anyio
from anyio.streams.memory import (
MemoryObjectReceiveStream,
MemoryObjectSendStream,
)
from anyiostream.process import (
MemoryBuffer,
Process,
ProcessConfig,
ProcessKind,
ResultStages,
)
T = TypeVar("T")
U = TypeVar("U")
# Sentinel for pipe.collect() / pipe.count() — tells __or__ to materialize
_COLLECT_SENTINEL = object()
_COUNT_SENTINEL = object()
_COLLECT_BATCH_SENTINEL = object()
_COLLECT_SPLIT_SENTINEL = object()
# ---------------------------------------------------------------------------
# Stream
# ---------------------------------------------------------------------------
[docs]
class Stream[T](ResultStages):
"""
A lazy, composable async pipeline.
Holds a list of ``Process`` descriptors and a source factory. Nothing
runs until a terminal operation (``collect``, ``count``, or ``open()``)
is invoked.
"""
__slots__ = ("_source_factory", "_processes")
[docs]
def __init__(
self,
source_factory: Callable[[MemoryObjectSendStream[Any]], Awaitable[None]],
processes: list[Process[Any, Any]] | None = None,
) -> None:
self._source_factory = source_factory
self._processes: list[Process[Any, Any]] = processes or []
# ------------------------------------------------------------------
# Constructors
# ------------------------------------------------------------------
[docs]
@classmethod
def from_iterable(
cls,
items: Iterable[T] | AsyncIterable[T],
*,
buffer_size: float = 0,
) -> Stream[T]:
"""
Create a stream from a sync or async iterable.
Args:
items: Source data.
buffer_size: Buffer between source and first process (not used directly
here — the first process's own ``buffer_size`` controls it).
"""
async def _produce(send: MemoryObjectSendStream[T]) -> None:
async with send:
if isinstance(items, AsyncIterable):
async for item in items:
await send.send(item)
else:
for item in items:
await send.send(item)
return cls(_produce)
[docs]
@classmethod
def from_callable(
cls,
factory: Callable[[], Iterable[T] | AsyncIterable[T]],
*,
buffer_size: float = 0,
) -> Stream[T]:
"""
Lazily create a stream: ``factory`` is called at execution time.
Useful when the source is expensive to create or can only be
iterated once.
"""
async def _produce(send: MemoryObjectSendStream[T]) -> None:
async with send:
source = factory()
if isinstance(source, AsyncIterable):
async for item in source:
await send.send(item)
else:
for item in source:
await send.send(item)
return cls(_produce)
# ------------------------------------------------------------------
# Process builders (method-chaining API)
# ------------------------------------------------------------------
[docs]
def map(
self,
func: Callable[[T], U | Awaitable[U]],
*,
workers: int = 1,
buffer_size: float = 0,
max_buffer_bytes: int = 10_000_000,
size_func: Callable[[Any], int] | None = None,
name: str | None = None,
) -> Stream[U]:
"""
1:1 transformation. ``func`` may be sync or async.
Args:
func: Transform function ``T -> U``.
workers: Concurrent workers for this process.
buffer_size: Backpressure buffer to downstream (item count).
max_buffer_bytes: Memory-based buffer limit in bytes.
Defaults to 10MB (10_000_000 bytes).
When set, enables MemoryBuffer for memory-aware buffering.
size_func: Optional function to calculate item size in bytes.
Only used when max_buffer_bytes is specified.
name: Label for tracing.
"""
process: Process[T, U] = Process(
kind=ProcessKind.MAP,
func=func,
config=ProcessConfig(
workers=workers,
buffer_size=buffer_size,
max_buffer_bytes=max_buffer_bytes,
size_func=size_func,
name=name,
),
)
return Stream(self._source_factory, [*self._processes, process])
[docs]
def flat_map(
self,
func: Callable[[T], AsyncIterable[U] | Iterable[U] | Awaitable[Any]],
*,
workers: int = 1,
buffer_size: float = 0,
max_buffer_bytes: int = 10_000_000,
size_func: Callable[[Any], int] | None = None,
name: str | None = None,
) -> Stream[U]:
"""
1:N transformation. ``func`` returns an iterable or async iterable.
Args:
func: Transform function ``T -> Iterable[U]`` or ``T -> AsyncIterable[U]``.
workers: Concurrent workers for this process.
buffer_size: Backpressure buffer to downstream.
max_buffer_bytes: Memory-based buffer limit in bytes.
Defaults to 10MB (10_000_000 bytes).
When set, enables MemoryBuffer for memory-aware buffering.
size_func: Optional function to calculate item size in bytes.
Only used when max_buffer_bytes is specified.
name: Label for tracing.
"""
process: Process[T, U] = Process(
kind=ProcessKind.FLAT_MAP,
func=func,
config=ProcessConfig(
workers=workers,
buffer_size=buffer_size,
max_buffer_bytes=max_buffer_bytes,
size_func=size_func,
name=name,
),
)
return Stream(self._source_factory, [*self._processes, process])
[docs]
def filter(
self,
predicate: Callable[[T], bool | Awaitable[bool]],
*,
workers: int = 1,
buffer_size: float = 0,
max_buffer_bytes: int = 10_000_000,
size_func: Callable[[Any], int] | None = None,
name: str | None = None,
) -> Stream[T]:
"""
Keep only items where ``predicate`` returns truthy.
Args:
predicate: Filter function ``T -> bool``.
workers: Concurrent workers.
buffer_size: Backpressure buffer to downstream.
max_buffer_bytes: Memory-based buffer limit in bytes.
Defaults to 10MB (10_000_000 bytes).
When set, enables MemoryBuffer for memory-aware buffering.
size_func: Optional function to calculate item size in bytes.
Only used when max_buffer_bytes is specified.
name: Label for tracing.
"""
process: Process[T, T] = Process(
kind=ProcessKind.FILTER,
func=predicate,
config=ProcessConfig(
workers=workers,
buffer_size=buffer_size,
max_buffer_bytes=max_buffer_bytes,
size_func=size_func,
name=name,
),
)
return Stream(self._source_factory, [*self._processes, process])
[docs]
def foreach(
self,
func: Callable[[T], Any],
*,
workers: int = 1,
buffer_size: float = 0,
max_buffer_bytes: int = 10_000_000,
size_func: Callable[[Any], int] | None = None,
name: str | None = None,
) -> Stream[T]:
"""
Side-effect process: call ``func`` on each item without changing it.
Useful for logging, metrics, or caching.
Args:
func: Side-effect function ``T -> None``.
workers: Concurrent workers.
buffer_size: Backpressure buffer to downstream.
max_buffer_bytes: Memory-based buffer limit in bytes.
Defaults to 10MB (10_000_000 bytes).
When set, enables MemoryBuffer for memory-aware buffering.
size_func: Optional function to calculate item size in bytes.
Only used when max_buffer_bytes is specified.
name: Label for tracing.
"""
process: Process[T, T] = Process(
kind=ProcessKind.FOREACH,
func=func,
config=ProcessConfig(
workers=workers,
buffer_size=buffer_size,
max_buffer_bytes=max_buffer_bytes,
size_func=size_func,
name=name,
),
)
return Stream(self._source_factory, [*self._processes, process])
[docs]
def flatten(
self,
*,
workers: int = 1,
buffer_size: float = 0,
max_buffer_bytes: int = 10_000_000,
size_func: Callable[[Any], int] | None = None,
name: str | None = None,
) -> Stream[Any]:
"""
Flatten iterables in the stream.
Converts a ``Stream[AsyncIterable[U] | Iterable[U]]`` into a
``Stream[U]`` by yielding each sub-item individually — equivalent
to ``flat_map(identity)``.
Args:
workers: Concurrent workers for this process.
buffer_size: Backpressure buffer to downstream.
max_buffer_bytes: Memory-based buffer limit in bytes.
Defaults to 10MB (10_000_000 bytes).
size_func: Optional function to calculate item size in bytes.
name: Label for tracing.
"""
return self.flat_map(
lambda x: x,
workers=workers,
buffer_size=buffer_size,
max_buffer_bytes=max_buffer_bytes,
size_func=size_func,
name=name,
)
# ------------------------------------------------------------------
# Execution engine (structured concurrency)
# ------------------------------------------------------------------
@asynccontextmanager
async def _execute(self) -> AsyncIterator[MemoryObjectReceiveStream[T]]:
"""
Materialize the pipeline and yield the output receive stream.
Runs inside ``async with create_task_group()`` so every spawned
task either completes or is cancelled when the block exits.
When a process has max_buffer_bytes set, integrates MemoryBuffer:
upstream → MemoryObjectStream(∞) → MemoryBuffer → MemoryObjectStream(buffer_size) → process
Usage::
async with stream._execute() as recv:
async for item in recv:
...
"""
processes = self._processes
if not processes:
# No processes — connect source directly to output
send, recv = anyio.create_memory_object_stream[Any](math.inf)
async with anyio.create_task_group() as tg:
tg.start_soon(self._source_factory, send)
try:
yield recv
finally:
recv.close()
tg.cancel_scope.cancel()
return
# Build channel chain:
# source → [ch0] → process0 → [ch1] → process1 → ... → [chN] → output
#
# ch0 : between source and first process
# ch1..chN : between process[i-1] and process[i], then final output
#
# When max_buffer_bytes is set for a process:
# upstream → [unbounded] → MemoryBuffer → [bounded] → process
channels: list[
tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]
] = []
# Source → first process channel
# MemoryBuffer is always enabled, so source channel is unbounded
channels.append(anyio.create_memory_object_stream[Any](math.inf))
# Inter-process + final output channels
for i, process in enumerate(processes):
# MemoryBuffer is always enabled, so inter-process channels are unbounded
if i + 1 < len(processes):
channels.append(anyio.create_memory_object_stream[Any](math.inf))
else:
channels.append(
anyio.create_memory_object_stream[Any](process.config.buffer_size)
)
output_recv = channels[-1][1]
async with anyio.create_task_group() as tg:
# 1. Source producer
tg.start_soon(self._source_factory, channels[0][0])
# 2. Each process with MemoryBuffer integration
for i, process in enumerate(processes):
# Create MemoryBuffer layer
# upstream → [unbounded recv from channels[i]] → MemoryBuffer → [bounded] → process → [channels[i+1]]
memory_buffer = MemoryBuffer(
max_buffer_bytes=process.config.max_buffer_bytes,
size_func=process.config.size_func,
)
# Create bounded channel between MemoryBuffer and process
mb_send, mb_recv = anyio.create_memory_object_stream[Any](
process.config.buffer_size
)
# Start MemoryBuffer: reads from upstream, writes to mb_send
tg.start_soon(memory_buffer.run, channels[i][1], mb_send)
# Start process: reads from mb_recv, writes to downstream
tg.start_soon(process.run, mb_recv, channels[i + 1][0])
# 3. Yield the final output stream to the caller
try:
yield output_recv
finally:
# On exit (normal or early), close the output stream and
# cancel in-flight tasks so the task group can exit cleanly.
output_recv.close()
tg.cancel_scope.cancel()
# ------------------------------------------------------------------
# Terminal operations
# ------------------------------------------------------------------
[docs]
async def collect(self, *, batch: bool = False) -> list[T]:
"""
Execute the pipeline and collect all outputs into a list.
Args:
batch: If ``True``, drain any ``AsyncIterable`` or ``Iterable``
items into sub-lists, so that a ``Stream[AsyncIterator[U]]``
produces ``list[list[U]]`` instead of ``list[AsyncIterator[U]]``.
Returns:
All items produced by the final process.
"""
results: list[T] = []
async with self._execute() as recv:
async for item in recv:
if batch:
if isinstance(item, AsyncIterable):
results.append([sub async for sub in item]) # type: ignore[arg-type]
elif isinstance(item, Iterable) and not isinstance(
item, (str, bytes)
):
results.append(list(item)) # type: ignore[arg-type]
else:
results.append(item)
else:
results.append(item)
return results
[docs]
async def count(self) -> int:
"""
Execute the pipeline, discarding outputs.
Returns:
Number of items processed.
"""
count = 0
async with self._execute() as recv:
async for _item in recv:
count += 1
return count
[docs]
async def reduce(
self,
func: Callable[[U, T], U | Awaitable[U]],
initial: U,
) -> U:
"""
Fold all items into a single value.
Args:
func: Reducer ``(acc, item) -> acc``.
initial: Starting accumulator value.
Returns:
Final accumulated value.
"""
acc = initial
async with self._execute() as recv:
async for item in recv:
result = func(acc, item)
if isinstance(result, Awaitable):
acc = await result
else:
acc = result
return acc
[docs]
async def first(self) -> T | None:
"""Return the first item, or ``None`` if the stream is empty."""
async with self._execute() as recv:
async for item in recv:
return item
return None
[docs]
async def take(self, n: int) -> list[T]:
"""Collect at most *n* items."""
results: list[T] = []
async with self._execute() as recv:
async for item in recv:
results.append(item)
if len(results) >= n:
break
return results
[docs]
@asynccontextmanager
async def open(self) -> AsyncIterator[MemoryObjectReceiveStream[T]]:
"""
Context-manager API for manual iteration with structured concurrency.
Usage::
async with stream.open() as items:
async for item in items:
process(item)
"""
async with self._execute() as recv:
yield recv
# ------------------------------------------------------------------
# Pipe operator
# ------------------------------------------------------------------
def __or__(self, other: Any) -> Any:
"""
Support ``stream | pipe.map(fn)`` syntax.
``other`` is either:
- A callable (``_PipeOp``) that returns a new ``Stream``
- ``_COLLECT_SENTINEL`` from ``pipe.collect()``
- ``_COUNT_SENTINEL`` from ``pipe.count()``
"""
if other is _COLLECT_SENTINEL:
return self.collect()
if other is _COLLECT_BATCH_SENTINEL:
return self.collect(batch=True)
if other is _COUNT_SENTINEL:
return self.count()
if other is _COLLECT_SPLIT_SENTINEL:
return self.collect_split()
if callable(other):
return other(self)
return NotImplemented