Source code for storey.aggregations

# Copyright 2020 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import re
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Union

import pandas as pd

from .aggregation_utils import is_aggregation_name
from .dtypes import (
    EmitAfterMaxEvent,
    EmitAfterPeriod,
    EmitAfterWindow,
    EmitEveryEvent,
    EmitPolicy,
    FieldAggregator,
    FixedWindows,
    FixedWindowType,
    SlidingWindows,
    _dict_to_emit_policy,
)
from .flow import Event, Flow, _termination_obj
from .table import Table
from .utils import stringify_key

_default_emit_policy = EmitEveryEvent()
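# Illustrative sketch (not part of the original module): the dictionary form of an
# aggregate accepted by AggregateByKey below, mirroring the keys that
# _parse_aggregates() reads. A "period" key selects SlidingWindows; without it,
# FixedWindows is used. "aggregation_filter" and "max_value" are optional.
#
# {
#     "name": "number_of_clicks",
#     "column": "click",
#     "operations": ["count"],
#     "windows": ["1h", "2h", "24h"],
#     "period": "10m",  # optional: makes this a sliding-window aggregate
# }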


class AggregateByKey(Flow):
    """
    Aggregates the data into the table object provided for later persistence, and outputs an event enriched with
    the requested aggregation features.
    Persistence is done via the `NoSqlTarget` step and is based on the Cache object's persistence settings.

    :param aggregates: List of aggregates to apply to each event. Accepts either a list of FieldAggregator objects
        or a list of dictionaries describing them.
    :param table: A Table object or name for persistence of aggregations. If a table name is provided, it will be
        looked up in the context object passed in kwargs.
    :param key_field: Key field to aggregate by, accepts either a string representing the key field or a key
        extracting function. Defaults to the key in the event's metadata. (Optional)
    :param time_field: Time field to aggregate by, accepts either a string representing the time field or a time
        extracting function. Defaults to the processing time in the event's metadata. (Optional)
    :param emit_policy: Policy indicating when the data will be emitted. Defaults to EmitEveryEvent.
    :param augmentation_fn: Function that augments the features into the event's body. Defaults to updating a dict.
        (Optional)
    :param enrich_with: List of attribute names from the associated storage object to be fetched and added to every
        event. (Optional)
    :param aliases: Dictionary specifying aliases for enriched or aggregate columns, of the format
        `{'col_name': 'new_col_name'}`. (Optional)
    :param use_windows_from_schema: If True, window definitions are read from the table schema rather than taken
        from the given aggregates. (Optional)
    :param time_format: If the value of the time field is of type string, this format will be used to parse it, as
        defined in datetime.strptime(). By default, parsing will follow ISO-8601.
    """

    def __init__(
        self,
        aggregates: Union[List[FieldAggregator], List[Dict[str, object]]],
        table: Union[Table, str],
        key_field: Union[str, List[str], Callable[[Event], object], None] = None,
        time_field: Union[str, Callable[[Event], object], None] = None,
        emit_policy: Union[EmitPolicy, Dict[str, object]] = _default_emit_policy,
        augmentation_fn: Optional[Callable[[Event, Dict[str, object]], Event]] = None,
        enrich_with: Optional[List[str]] = None,
        aliases: Optional[Dict[str, str]] = None,
        use_windows_from_schema: bool = False,
        time_format: Optional[str] = None,
        **kwargs,
    ):
        Flow.__init__(self, **kwargs)

        aggregates = self._parse_aggregates(aggregates)
        self._check_unique_names(aggregates)

        self._table = table
        if isinstance(table, str):
            if not self.context:
                raise TypeError("Table can not be string if no context was provided to the step")
            self._table = self.context.get_table(table)
        self._table._set_aggregation_metadata(aggregates, use_windows_from_schema=use_windows_from_schema)
        self._closeables = [self._table]
        self._aggregates_metadata = aggregates

        self._enrich_with = enrich_with or []
        self._aliases = aliases or {}

        if not isinstance(emit_policy, EmitPolicy) and not isinstance(emit_policy, Dict):
            raise TypeError(
                f"emit_policy parameter must be of type EmitPolicy, or dict. Found {type(emit_policy)} instead."
            )
        self._emit_policy = emit_policy
        if isinstance(self._emit_policy, dict):
            self._emit_policy = _dict_to_emit_policy(self._emit_policy)

        self._augmentation_fn = augmentation_fn
        if not augmentation_fn:

            def f(element, features):
                features.update(element)
                return features

            self._augmentation_fn = f

        self._key_extractor = None
        if key_field:
            if callable(key_field):
                self._key_extractor = key_field
            elif isinstance(key_field, str):
                self._key_extractor = lambda event: event.get(key_field)
            elif isinstance(key_field, list):
                self._key_extractor = lambda event: [event.get(single_key) for single_key in key_field]
            else:
                raise TypeError(
                    f"key_field is expected to be either a callable or string or list of strings "
                    f"but got {type(key_field)}"
                )

        self._time_extractor = lambda event: event.processing_time
        if time_field:
            if callable(time_field):
                self._time_extractor = time_field
            elif isinstance(time_field, str):
                self._time_extractor = lambda event: event.body.get(time_field)
            else:
                raise TypeError(f"time_field is expected to be either a callable or string but got {type(time_field)}")

        self._time_format = time_format

    def _init(self):
        super()._init()
        self._events_in_batch = {}
        self._emit_worker_running = False
        self._terminate_worker = False
        self._timeout_task: Optional[asyncio.Task] = None

    def _check_unique_names(self, aggregates):
        unique_aggr_names = set()
        for aggr in aggregates:
            if aggr.name in unique_aggr_names:
                raise TypeError(f"Aggregates should have unique names. {aggr.name} already exists")
            unique_aggr_names.add(aggr.name)

    @staticmethod
    def _parse_aggregates(aggregates):
        if not isinstance(aggregates, list):
            raise TypeError("aggregates should be a list of FieldAggregator/dictionaries")
        if not aggregates or isinstance(aggregates[0], FieldAggregator):
            return aggregates
        if isinstance(aggregates[0], dict):
            new_aggregates = []
            for aggregate_dict in aggregates:
                if "period" in aggregate_dict:
                    window = SlidingWindows(aggregate_dict["windows"], aggregate_dict["period"])
                else:
                    window = FixedWindows(aggregate_dict["windows"])

                new_aggregates.append(
                    FieldAggregator(
                        aggregate_dict["name"],
                        aggregate_dict["column"],
                        aggregate_dict["operations"],
                        window,
                        aggregate_dict.get("aggregation_filter", None),
                        aggregate_dict.get("max_value", None),
                    )
                )
            return new_aggregates
        raise TypeError("aggregates should be a list of FieldAggregator/dictionaries")

    def _get_timestamp(self, event):
        event_timestamp = self._time_extractor(event)
        if isinstance(event_timestamp, str):
            if self._time_format:
                event_timestamp = datetime.strptime(event_timestamp, self._time_format)
            else:
                event_timestamp = datetime.fromisoformat(event_timestamp)
        if isinstance(event_timestamp, datetime):
            if isinstance(event_timestamp, pd.Timestamp) and event_timestamp.tzinfo is None:
                # timestamp() on a pandas Timestamp gives the wrong result when there is no timezone (ML-313)
                local_time_zone = datetime.now().astimezone().tzinfo
                event_timestamp = event_timestamp.replace(tzinfo=local_time_zone)
            event_timestamp = event_timestamp.timestamp() * 1000
        return event_timestamp

    async def _do(self, event):
        if event == _termination_obj:
            self._terminate_worker = True
            return await self._do_downstream(_termination_obj)

        try:
            # Check whether a background emit worker is needed; if so, start one
            if (not self._emit_worker_running) and (
                isinstance(self._emit_policy, EmitAfterPeriod) or isinstance(self._emit_policy, EmitAfterWindow)
            ):
                asyncio.get_running_loop().create_task(self._emit_worker())
                self._emit_worker_running = True

            element = event.body
            key = event.key
            if self._key_extractor:
                key = self._key_extractor(element)

            event_timestamp = self._get_timestamp(event)

            safe_key = stringify_key(key)
            await self._table._lazy_load_key_with_aggregates(safe_key, event_timestamp)
            await self._table._aggregate(safe_key, event, element, event_timestamp)

            if isinstance(self._emit_policy, EmitEveryEvent):
                await self._emit_event(key, event)
            elif isinstance(self._emit_policy, EmitAfterMaxEvent):
                if safe_key in self._events_in_batch:
                    self._events_in_batch[safe_key]["counter"] += 1
                else:
                    event_dict = {"counter": 1, "time": time.monotonic()}
                    self._events_in_batch[safe_key] = event_dict
                self._events_in_batch[safe_key]["event"] = event
                if self._emit_policy.timeout_secs and self._timeout_task is None:
                    self._timeout_task = asyncio.get_running_loop().create_task(self._sleep_and_emit())

                if self._events_in_batch[safe_key]["counter"] == self._emit_policy.max_events:
                    event_from_batch = self._events_in_batch.pop(safe_key, None)
                    if event_from_batch is not None:
                        await self._emit_event(key, event_from_batch["event"])
        except Exception as ex:
            raise ex

    async def _sleep_and_emit(self):
        while self._events_in_batch:
            key = next(iter(self._events_in_batch.keys()))
            delta_seconds = time.monotonic() - self._events_in_batch[key]["time"]
            if delta_seconds < self._emit_policy.timeout_secs:
                await asyncio.sleep(self._emit_policy.timeout_secs - delta_seconds)

            event = self._events_in_batch.pop(key, None)
            if event is not None:
                await self._emit_event(key, event["event"])

        self._timeout_task = None

    # Emit a single event for the requested key
    async def _emit_event(self, key, event):
        event_timestamp = self._get_timestamp(event)
        safe_key = stringify_key(key)
        await self._table._lazy_load_key_with_aggregates(safe_key, event_timestamp)
        features = await self._table._get_features(safe_key, event_timestamp)
        for feature_name in list(features.keys()):
            if feature_name in self._aliases:
                new_feature_name = self._aliases[feature_name]
                if feature_name != new_feature_name:
                    features[new_feature_name] = features[feature_name]
                    del features[feature_name]
        features = self._augmentation_fn(event.body, features)
        for col in self._enrich_with:
            emitted_attr_name = self._aliases.get(col, None) or col
            if col in self._table._get_static_attrs(safe_key):
                features[emitted_attr_name] = self._table._get_static_attrs(safe_key)[col]
        event.key = key
        event.body = features
        await self._do_downstream(event)

    # Emit one event for every key in the store, stamped with the current time
    async def _emit_all_events(self, timestamp):
        for key in self._table._get_keys():
            await self._emit_event(key, Event({"key": key, "time": timestamp}, key, timestamp, None))

    async def _emit_worker(self):
        if isinstance(self._emit_policy, EmitAfterPeriod):
            seconds_to_sleep_between_emits = self._aggregates_metadata[0].windows.period_millis / 1000
        elif isinstance(self._emit_policy, EmitAfterWindow):
            seconds_to_sleep_between_emits = self._aggregates_metadata[0].windows.windows[0][0] / 1000
        else:
            raise TypeError(f'Emit policy "{type(self._emit_policy)}" is not supported')

        current_time = datetime.now().timestamp()
        next_emit_time = (
            int(current_time / seconds_to_sleep_between_emits) * seconds_to_sleep_between_emits
            + seconds_to_sleep_between_emits
        )

        while not self._terminate_worker:
            current_time = datetime.now().timestamp()
            next_sleep_interval = next_emit_time - current_time + self._emit_policy.delay_in_seconds
            if next_sleep_interval > 0:
                await asyncio.sleep(next_sleep_interval)
            await self._emit_all_events(next_emit_time * 1000)
            next_emit_time = next_emit_time + seconds_to_sleep_between_emits
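
# Illustrative usage sketch (not part of the original module): a minimal pipeline
# that counts clicks per user over sliding windows and, with the default
# EmitEveryEvent policy, emits each event enriched with the aggregation features.
# Assumes storey's SyncEmitSource, NoopDriver, and Reduce steps and a dict event
# body with a "click" field; verify these names against your storey version.
def _example_aggregate_by_key():
    from storey import NoopDriver, Reduce, SyncEmitSource, build_flow

    table = Table("clicks", NoopDriver())  # in-memory table; nothing is persisted
    controller = build_flow(
        [
            SyncEmitSource(),
            AggregateByKey(
                [
                    FieldAggregator(
                        "number_of_clicks", "click", ["count"], SlidingWindows(["1h", "2h", "24h"], "10m")
                    )
                ],
                table,
            ),
            Reduce([], lambda acc, event: acc + [event]),  # collect emitted bodies
        ]
    ).run()

    controller.emit({"click": 1}, "user_1")  # aggregate under key "user_1"
    controller.terminate()
    # Each collected body now also carries features such as "number_of_clicks_count_1h"
    return controller.await_termination()
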
class QueryByKey(AggregateByKey):
    """
    Query features by name.

    :param features: List of features to get.
    :param table: A Table object or name for persistence of aggregations. If a table name is provided, it will be
        looked up in the context object passed in kwargs.
    :param key_field: Key field to query by, accepts either a string representing the key field or a key extracting
        function. Defaults to the key in the event's metadata. Can be a list of keys. (Optional)
    :param time_field: Time field to query by, accepts either a string representing the time field or a time
        extracting function. Defaults to the processing time in the event's metadata. (Optional)
    :param augmentation_fn: Function that augments the features into the event's body. Defaults to updating a dict.
        (Optional)
    :param aliases: Dictionary specifying aliases for enriched or aggregate columns, of the format
        `{'col_name': 'new_col_name'}`. (Optional)
    :param fixed_window_type: Enum flag specifying how fixed windows are queried. Defaults to
        FixedWindowType.CurrentOpenWindow. (Optional)
    """

    def __init__(
        self,
        features: List[str],
        table: Union[Table, str],
        key_field: Union[str, List[str], Callable[[Event], object], None] = None,
        time_field: Union[str, List[str], Callable[[Event], object], None] = None,
        augmentation_fn: Optional[Callable[[Event, Dict[str, object]], Event]] = None,
        aliases: Optional[Dict[str, str]] = None,
        fixed_window_type: Optional[FixedWindowType] = FixedWindowType.CurrentOpenWindow,
        **kwargs,
    ):
        self._aggrs = []
        self._enrich_cols = []

        resolved_aggrs = {}
        if isinstance(table, str):
            if "context" not in kwargs:
                raise TypeError("Table can not be string if no context was provided to the step")
            table = kwargs["context"].get_table(table)
        for feature in features:
            match = re.match(r".*_([a-z]+)_[0-9]+[smhd]$", feature) if table.supports_aggregations() else None
            if match and is_aggregation_name(match.group(1)):
                name, window = feature.rsplit("_", 1)
                if name in resolved_aggrs:
                    resolved_aggrs[name].append(window)
                else:
                    resolved_aggrs[name] = [window]
            else:
                self._enrich_cols.append(feature)

        for name, windows in resolved_aggrs.items():
            feature, aggr = name.rsplit("_", 1)
            # Set as SlidingWindows temporarily, until the actual window type is read from the schema
            self._aggrs.append(
                FieldAggregator(
                    name=feature,
                    field=None,
                    aggr=[aggr],
                    windows=SlidingWindows(windows),
                )
            )

        other_table = table._clone() if table._aggregates is not None else table
        AggregateByKey.__init__(
            self,
            self._aggrs,
            other_table,
            key_field,
            time_field,
            augmentation_fn=augmentation_fn,
            enrich_with=self._enrich_cols,
            aliases=aliases,
            use_windows_from_schema=True,
            **kwargs,
        )
        self._table._aggregations_read_only = True
        self._table.fixed_window_type = fixed_window_type

    async def _do(self, event):
        if event == _termination_obj:
            self._terminate_worker = True
            return await self._do_downstream(_termination_obj)

        element = event.body
        key = event.key
        if self._key_extractor:
            if element:
                key = self._key_extractor(element)
        if key is None or key == [None] or element is None:
            event.body = None
            await self._do_downstream(event)
            return

        await self._emit_event(key, event)

    def _check_unique_names(self, aggregates):
        pass
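
# Illustrative query sketch (not part of the original module): reads features back
# by name from a table previously written by an AggregateByKey pipeline. Names that
# match "<aggregate>_<operation>_<window>" (e.g. "number_of_clicks_count_1h") are
# resolved as aggregations; other names are fetched as plain enrichment columns.
# A table backed by a real driver is assumed; with NoopDriver nothing is persisted
# across flows, so in practice run this in the same flow or against real storage.
def _example_query_by_key(table):
    from storey import Reduce, SyncEmitSource, build_flow

    controller = build_flow(
        [
            SyncEmitSource(),
            QueryByKey(["number_of_clicks_count_1h"], table),
            Reduce([], lambda acc, event: acc + [event]),  # collect query results
        ]
    ).run()

    controller.emit({}, "user_1")  # query the features for key "user_1"
    controller.terminate()
    return controller.await_termination()
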