Source code for storey.steps.collector

# Copyright 2026 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import copy
import datetime
from collections import defaultdict

from ..dtypes import StreamCompletion, _termination_obj
from ..flow import Flow


class Collector(Flow):
    """
    Collects streaming chunks and emits a single event once all chunks for a stream are received.

    Acts as a no-op passthrough for non-streaming events.

    This step accumulates chunks from upstream streaming steps until a StreamCompletion sentinel is
    received. Once all expected completions are received, it emits a single event built from a
    shallow copy of the first chunk event (or of the completion's original event when no chunk was
    received). The body of the emitted event is:

    * ``{"error": <error>}`` if the completion that reaches the expected count carries an error;
    * the single chunk body itself if exactly one chunk was collected;
    * otherwise the list of collected chunk bodies (an empty list for an empty stream).

    :param expected_completions: The number of StreamCompletion signals expected for a given event
        stream. Useful when there are upstream splits that duplicate streaming events. Defaults to 1.
    :type expected_completions: int
    :param name: Name of this step, as it should appear in logs. Defaults to class name (Collector).
    :type name: string
    :raises ValueError: If ``expected_completions`` is smaller than 1.
    """

    def __init__(self, expected_completions: int = 1, **kwargs):
        super().__init__(**kwargs)
        if expected_completions < 1:
            raise ValueError("expected_completions must be at least 1")
        self._expected_completions = expected_completions
        # Map from event id -> {"chunks": [], "completions": 0, "first_event": Event}
        self._collected_streams: dict[str, dict] = defaultdict(
            lambda: {"chunks": [], "completions": 0, "first_event": None}
        )

    def _calculate_streaming_duration(self, event):
        """
        Calculate total streaming duration and update event metadata with microsec.

        Uses the 'when' timestamp from the first chunk's metadata (set by ParallelExecution)
        to calculate total elapsed time from stream start to completion.

        Streaming is only supported with a single selected runnable, so metadata is always flat
        (top-level 'when' and 'microsec'), never nested under model names.

        The event is left untouched when it has no (or empty) ``_metadata`` or no 'when' entry.
        A 'when' value that cannot be parsed, or that cannot be subtracted from the current
        timezone-aware UTC time, is reported through the step logger (if any) and otherwise ignored.

        :param event: The collected event whose ``_metadata`` is updated in place.
        """
        if not hasattr(event, "_metadata") or not event._metadata:
            return

        when_str = event._metadata.get("when")
        if not when_str:
            return

        try:
            start_time = datetime.datetime.fromisoformat(when_str)
            now = datetime.datetime.now(tz=datetime.timezone.utc)
            event._metadata["microsec"] = int((now - start_time).total_seconds() * 1_000_000)
        except (ValueError, TypeError) as exc:
            # Timing is best-effort: never fail the stream because of a bad timestamp.
            if self.logger:
                self.logger.warning(f"Failed to calculate streaming duration from 'when' timestamp '{when_str}': {exc}")

    async def _do(self, event):
        """
        Process one incoming event.

        * The termination object is forwarded downstream as is.
        * A StreamCompletion increments the completion count of its stream and, once the expected
          number of completions is reached, triggers emission of the collected event.
        * An event with a ``streaming_step`` attribute is buffered as a chunk of its stream.
        * Any other event is passed downstream unchanged.

        :param event: The incoming event, StreamCompletion sentinel or termination object.
        :return: The result of the downstream call for termination and non-streaming events,
            ``None`` for chunks and completions.
        """
        if event is _termination_obj:
            return await self._do_downstream(_termination_obj)

        # Handle StreamCompletion sentinel
        if isinstance(event, StreamCompletion):
            stream_id = event.original_event.id
            stream_data = self._collected_streams[stream_id]
            stream_data["completions"] += 1

            if stream_data["completions"] >= self._expected_completions:
                # Stream is complete - emit collected result.
                # The state is released in ``finally`` so that a failure while building or emitting
                # the collected event cannot leave the buffered chunks behind forever.
                try:
                    # Use first_event if we have chunks, otherwise use original_event from
                    # completion (empty stream)
                    base_event = stream_data["first_event"] or event.original_event

                    # NOTE(review): only the error carried by the completion that reaches the
                    # expected count is reported; errors on earlier completions are not inspected.
                    if event.error:
                        collected_body = {"error": event.error}
                    else:
                        collected_body = [chunk.body for chunk in stream_data["chunks"]]
                        if len(collected_body) == 1:
                            collected_body = collected_body[0]

                    # Copy the original event to preserve all attributes (important for offset
                    # management)
                    collected_event = copy.copy(base_event)
                    collected_event.body = collected_body

                    # Clear streaming attributes and mark as collected
                    if hasattr(collected_event, "streaming_step"):
                        del collected_event.streaming_step
                    if hasattr(collected_event, "chunk_id"):
                        del collected_event.chunk_id
                    collected_event.stream_collected = True

                    # Calculate total streaming duration (microsec) if timing metadata exists
                    self._calculate_streaming_duration(collected_event)

                    await self._do_downstream(collected_event)
                finally:
                    # Clean up
                    self._collected_streams.pop(stream_id, None)
            return None

        # Check if this is a streaming chunk (has streaming_step attribute)
        if hasattr(event, "streaming_step"):
            stream_id = event.id
            stream_data = self._collected_streams[stream_id]
            if stream_data["first_event"] is None:
                stream_data["first_event"] = event
            stream_data["chunks"].append(event)
            return None
        else:
            # Non-streaming event - pass through directly
            return await self._do_downstream(event)