import asyncio
from datetime import datetime
import time
import re
from typing import Optional, Union, Callable, List, Dict
import pandas as pd
from .dtypes import EmitEveryEvent, FixedWindows, SlidingWindows, EmitAfterPeriod, EmitAfterWindow, EmitAfterMaxEvent, \
_dict_to_emit_policy, FieldAggregator, EmitPolicy, FixedWindowType
from .table import Table
from .flow import Flow, _termination_obj, Event
from .utils import stringify_key
# Shared default emit policy: emit aggregation features on every incoming event.
_default_emit_policy = EmitEveryEvent()
class AggregateByKey(Flow):
    """
    Aggregates the data into the table object provided for later persistence, and outputs an event enriched with the requested aggregation
    features.
    Persistence is done via the `NoSqlTarget` step and based on the Cache object persistence settings.

    :param aggregates: List of aggregates to apply for each event.
        accepts either list of FieldAggregators or a dictionary describing FieldAggregators.
    :param table: A Table object or name for persistence of aggregations.
        If a table name is provided, it will be looked up in the context object passed in kwargs.
    :param key: Key field to aggregate by, accepts either a string representing the key field or a key extracting function.
        Defaults to the key in the event's metadata. (Optional)
    :param emit_policy: Policy indicating when the data will be emitted. Defaults to EmitEveryEvent
    :param augmentation_fn: Function that augments the features into the event's body. Defaults to updating a dict. (Optional)
    :param enrich_with: List of attributes names from the associated storage object to be fetched and added to every event. (Optional)
    :param aliases: Dictionary specifying aliases for enriched or aggregate columns, of the
        format `{'col_name': 'new_col_name'}`. (Optional)
    """

    def __init__(self, aggregates: Union[List[FieldAggregator], List[Dict[str, object]]], table: Union[Table, str],
                 key: Union[str, Callable[[Event], object], None] = None,
                 emit_policy: Union[EmitPolicy, Dict[str, object]] = _default_emit_policy,
                 augmentation_fn: Optional[Callable[[Event, Dict[str, object]], Event]] = None, enrich_with: Optional[List[str]] = None,
                 aliases: Optional[Dict[str, str]] = None, use_windows_from_schema: bool = False, **kwargs):
        Flow.__init__(self, **kwargs)
        aggregates = self._parse_aggregates(aggregates)
        self._check_unique_names(aggregates)
        self._table = table
        if isinstance(table, str):
            # A table name can only be resolved through the context object
            if not self.context:
                raise TypeError("Table can not be string if no context was provided to the step")
            self._table = self.context.get_table(table)
        self._table._set_aggregation_metadata(aggregates, use_windows_from_schema=use_windows_from_schema)
        self._closeables = [self._table]
        self._aggregates_metadata = aggregates
        self._enrich_with = enrich_with or []
        self._aliases = aliases or {}
        # Fix: check against the builtin dict rather than typing.Dict — isinstance()
        # with typing aliases is deprecated and fails on newer Python versions.
        if not isinstance(emit_policy, (EmitPolicy, dict)):
            raise TypeError(f'emit_policy parameter must be of type EmitPolicy, or dict. Found {type(emit_policy)} instead.')
        self._emit_policy = emit_policy
        if isinstance(self._emit_policy, dict):
            self._emit_policy = _dict_to_emit_policy(self._emit_policy)
        self._augmentation_fn = augmentation_fn
        if not augmentation_fn:
            # Default augmentation: merge the original event body into the features dict
            def f(element, features):
                features.update(element)
                return features
            self._augmentation_fn = f
        self.key_extractor = None
        if key:
            if callable(key):
                self.key_extractor = key
            elif isinstance(key, str):
                self.key_extractor = lambda element: element.get(key)
            elif isinstance(key, list):
                # Composite key: extract every listed field, preserving order
                self.key_extractor = lambda element: [element.get(single_key) for single_key in key]
            else:
                raise TypeError(f'key is expected to be either a callable or string but got {type(key)}')

    def _init(self):
        """Reset per-run state (batching counters and background worker flags)."""
        super()._init()
        self._events_in_batch = {}
        self._emit_worker_running = False
        self._terminate_worker = False
        self._timeout_task: Optional[asyncio.Task] = None

    def _check_unique_names(self, aggregates):
        """Raise TypeError if two aggregates share the same name."""
        unique_aggr_names = set()
        for aggr in aggregates:
            if aggr.name in unique_aggr_names:
                raise TypeError(f'Aggregates should have unique names. {aggr.name} already exists')
            unique_aggr_names.add(aggr.name)

    @staticmethod
    def _parse_aggregates(aggregates):
        """Normalize the aggregates argument to a list of FieldAggregator objects.

        Accepts either a list of FieldAggregator instances (returned as-is) or a list
        of dicts describing them; a 'period' key selects SlidingWindows, otherwise
        FixedWindows is used.
        """
        if not isinstance(aggregates, list):
            raise TypeError('aggregates should be a list of FieldAggregator/dictionaries')
        if not aggregates or isinstance(aggregates[0], FieldAggregator):
            return aggregates
        if isinstance(aggregates[0], dict):
            new_aggregates = []
            for aggregate_dict in aggregates:
                if 'period' in aggregate_dict:
                    window = SlidingWindows(aggregate_dict['windows'], aggregate_dict['period'])
                else:
                    window = FixedWindows(aggregate_dict['windows'])
                new_aggregates.append(FieldAggregator(aggregate_dict['name'], aggregate_dict['column'], aggregate_dict['operations'],
                                                      window, aggregate_dict.get('aggregation_filter', None),
                                                      aggregate_dict.get('max_value', None)))
            return new_aggregates
        raise TypeError('aggregates should be a list of FieldAggregator/dictionaries')

    def _get_timestamp(self, event):
        """Return the event time as epoch milliseconds when it is a datetime; otherwise pass it through unchanged."""
        event_timestamp = event.time
        if isinstance(event_timestamp, datetime):
            if isinstance(event_timestamp, pd.Timestamp) and event_timestamp.tzinfo is None:
                # timestamp for pandas timestamp gives the wrong result in case there is no timezone (ML-313)
                local_time_zone = datetime.now().astimezone().tzinfo
                event_timestamp = event_timestamp.replace(tzinfo=local_time_zone)
            event_timestamp = event_timestamp.timestamp() * 1000
        return event_timestamp

    async def _do(self, event):
        """Aggregate one event into the table, then emit features according to the emit policy."""
        if event == _termination_obj:
            self._terminate_worker = True
            return await self._do_downstream(_termination_obj)
        # Fix: removed a no-op `try/except Exception as ex: raise ex` wrapper that only
        # truncated tracebacks without handling anything.
        # check whether a background loop is needed, if so create start one
        if (not self._emit_worker_running) and \
                isinstance(self._emit_policy, (EmitAfterPeriod, EmitAfterWindow)):
            asyncio.get_running_loop().create_task(self._emit_worker())
            self._emit_worker_running = True
        element = event.body
        key = event.key
        if self.key_extractor:
            key = self.key_extractor(element)
        event_timestamp = self._get_timestamp(event)
        safe_key = stringify_key(key)
        await self._table._lazy_load_key_with_aggregates(safe_key, event_timestamp)
        await self._table._aggregate(safe_key, element, event_timestamp)
        if isinstance(self._emit_policy, EmitEveryEvent):
            await self._emit_event(key, event)
        elif isinstance(self._emit_policy, EmitAfterMaxEvent):
            if safe_key in self._events_in_batch:
                self._events_in_batch[safe_key]['counter'] += 1
            else:
                event_dict = {'counter': 1, 'time': time.monotonic()}
                self._events_in_batch[safe_key] = event_dict
            # Always keep the most recent event so the batch emits the freshest body
            self._events_in_batch[safe_key]['event'] = event
            if self._emit_policy.timeout_secs and self._timeout_task is None:
                self._timeout_task = asyncio.get_running_loop().create_task(self._sleep_and_emit())
            if self._events_in_batch[safe_key]['counter'] == self._emit_policy.max_events:
                event_from_batch = self._events_in_batch.pop(safe_key, None)
                if event_from_batch is not None:
                    await self._emit_event(key, event_from_batch['event'])

    async def _sleep_and_emit(self):
        """Background task: emit each pending batch once its timeout has elapsed, in insertion order."""
        while self._events_in_batch:
            key = next(iter(self._events_in_batch.keys()))
            delta_seconds = time.monotonic() - self._events_in_batch[key]['time']
            if delta_seconds < self._emit_policy.timeout_secs:
                await asyncio.sleep(self._emit_policy.timeout_secs - delta_seconds)
            event = self._events_in_batch.pop(key, None)
            if event is not None:
                await self._emit_event(key, event['event'])
        self._timeout_task = None

    # Emit a single event for the requested key
    async def _emit_event(self, key, event):
        """Fetch features for `key`, apply aliases/enrichment/augmentation, and push the event downstream."""
        event_timestamp = self._get_timestamp(event)
        safe_key = stringify_key(key)
        await self._table._lazy_load_key_with_aggregates(safe_key, event_timestamp)
        features = await self._table._get_features(safe_key, event_timestamp)
        # Rename feature columns per the aliases mapping (iterate over a copy since we mutate)
        for feature_name in list(features.keys()):
            if feature_name in self._aliases:
                new_feature_name = self._aliases[feature_name]
                if feature_name != new_feature_name:
                    features[new_feature_name] = features[feature_name]
                    del features[feature_name]
        features = self._augmentation_fn(event.body, features)
        for col in self._enrich_with:
            emitted_attr_name = self._aliases.get(col, None) or col
            if col in self._table._get_static_attrs(safe_key):
                features[emitted_attr_name] = self._table._get_static_attrs(safe_key)[col]
        event.key = key
        event.body = features
        await self._do_downstream(event)

    # Emit multiple events for every key in the store with the current time
    async def _emit_all_events(self, timestamp):
        for key in self._table._get_keys():
            await self._emit_event(key, Event({'key': key, 'time': timestamp}, key, timestamp, None))

    async def _emit_worker(self):
        """Background loop for time-based emit policies: wake at window/period boundaries and emit all keys."""
        if isinstance(self._emit_policy, EmitAfterPeriod):
            seconds_to_sleep_between_emits = self._aggregates_metadata[0].windows.period_millis / 1000
        elif isinstance(self._emit_policy, EmitAfterWindow):
            seconds_to_sleep_between_emits = self._aggregates_metadata[0].windows.windows[0][0] / 1000
        else:
            raise TypeError(f'Emit policy "{type(self._emit_policy)}" is not supported')
        # Align the first emit to the next interval boundary
        current_time = datetime.now().timestamp()
        next_emit_time = int(
            current_time / seconds_to_sleep_between_emits) * seconds_to_sleep_between_emits + seconds_to_sleep_between_emits
        while not self._terminate_worker:
            current_time = datetime.now().timestamp()
            next_sleep_interval = next_emit_time - current_time + self._emit_policy.delay_in_seconds
            if next_sleep_interval > 0:
                await asyncio.sleep(next_sleep_interval)
            await self._emit_all_events(next_emit_time * 1000)
            next_emit_time = next_emit_time + seconds_to_sleep_between_emits
class QueryByKey(AggregateByKey):
    """
    Query features by name

    :param features: List of features to get.
    :param table: A Table object or name for persistence of aggregations.
        If a table name is provided, it will be looked up in the context object passed in kwargs.
    :param key: Key field to aggregate by, accepts either a string representing the key field or a key extracting function.
        Defaults to the key in the event's metadata. (Optional). Can be list of keys
    :param augmentation_fn: Function that augments the features into the event's body. Defaults to updating a dict. (Optional)
    :param aliases: Dictionary specifying aliases for enriched or aggregate columns, of the
        format `{'col_name': 'new_col_name'}`. (Optional)
    :param options: Enum flags specifying query options. (Optional)
    """

    def __init__(self, features: List[str], table: Union[Table, str], key: Union[str, List[str], Callable[[Event], object], None] = None,
                 augmentation_fn: Optional[Callable[[Event, Dict[str, object]], Event]] = None,
                 aliases: Optional[Dict[str, str]] = None,
                 fixed_window_type: Optional[FixedWindowType] = FixedWindowType.CurrentOpenWindow, **kwargs):
        self._aggrs = []
        self._enrich_cols = []
        # Feature names like "<col>_<op>_<window>" (e.g. "total_sum_1h") are aggregate
        # queries; anything else is treated as a plain column to enrich with.
        aggregate_pattern = re.compile(r".*_[a-z]+_[0-9]+[smhd]")
        windows_by_name = {}
        for feature in features:
            if aggregate_pattern.match(feature):
                base_name, _, window = feature.rpartition('_')
                windows_by_name.setdefault(base_name, []).append(window)
            else:
                self._enrich_cols.append(feature)
        for base_name, windows in windows_by_name.items():
            field_name, _, operation = base_name.rpartition('_')
            # setting as SlidingWindow temporarily until actual window type will be read from schema
            self._aggrs.append(FieldAggregator(name=field_name, field=None, aggr=[operation], windows=SlidingWindows(windows)))
        if isinstance(table, Table):
            # Clone to avoid mutating a table that already carries aggregation metadata
            other_table = table._clone() if table._aggregates is not None else table
        else:
            other_table = table  # str - pass table string along with the context object
        AggregateByKey.__init__(self, self._aggrs, other_table, key, augmentation_fn=augmentation_fn,
                                enrich_with=self._enrich_cols, aliases=aliases, use_windows_from_schema=True, **kwargs)
        self._table._aggregations_read_only = True
        self._table.fixed_window_type = fixed_window_type

    async def _do(self, event):
        """Look up features for the event's key and push the enriched event downstream."""
        if event == _termination_obj:
            self._terminate_worker = True
            return await self._do_downstream(_termination_obj)
        element = event.body
        key = event.key
        if self.key_extractor and element:
            key = self.key_extractor(element)
        if element is None or key is None or key == [None]:
            # Nothing to query — forward an empty-bodied event
            event.body = None
            await self._do_downstream(event)
            return
        await self._emit_event(key, event)

    def _check_unique_names(self, aggregates):
        # Queries may legitimately reference the same aggregate name multiple times
        pass