# Copyright 2020 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import re
import time
from datetime import datetime
from typing import Callable, Dict, List, Optional, Union

import pandas as pd

from .aggregation_utils import is_aggregation_name
from .dtypes import (
EmitAfterMaxEvent,
EmitAfterPeriod,
EmitAfterWindow,
EmitEveryEvent,
EmitPolicy,
FieldAggregator,
FixedWindows,
FixedWindowType,
SlidingWindows,
_dict_to_emit_policy,
)
from .flow import Event, Flow, _termination_obj
from .table import Table
from .utils import stringify_key

_default_emit_policy = EmitEveryEvent()


class AggregateByKey(Flow):
"""
Aggregates the data into the table object provided for later persistence,
and outputs an event enriched with the requested aggregation features.
    Persistence is done via the `NoSqlTarget` step, based on the table's persistence settings.

    :param aggregates: List of aggregates to apply for each event. Accepts either a list of
        FieldAggregator objects or a list of dictionaries describing FieldAggregators.
:param table: A Table object or name for persistence of aggregations.
If a table name is provided, it will be looked up in the context object passed in kwargs.
    :param key_field: Key field to aggregate by, accepts either a string representing the key field, a list of
        such strings, or a key extracting function. Defaults to the key in the event's metadata. (Optional)
:param time_field: Time field to aggregate by, accepts either a string representing the time field or a time
extracting function. Defaults to the processing time in the event's metadata. (Optional)
    :param emit_policy: Policy indicating when the data will be emitted. Defaults to EmitEveryEvent. (Optional)
:param augmentation_fn: Function that augments the features into the event's body. Defaults to
updating a dict. (Optional)
:param enrich_with: List of attributes names from the associated storage object to be fetched
and added to every event. (Optional)
    :param aliases: Dictionary specifying aliases for enriched or aggregate columns, of the
        format `{'col_name': 'new_col_name'}`. (Optional)
    :param use_windows_from_schema: Whether to use the aggregation windows already defined in the
        table's schema instead of those defined on the aggregates. Defaults to False. (Optional)
:param time_format: If the value of the time field is of type string, this format will be used to parse it, as
defined in datetime.strptime(). By default, parsing will follow ISO-8601.
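
    Example (a minimal, illustrative sketch; assumes `NoopDriver`, `SyncEmitSource` and `build_flow`
    from storey's public API, none of which are defined in this module)::

        table = Table("users", NoopDriver())
        controller = build_flow([
            SyncEmitSource(),
            AggregateByKey(
                [FieldAggregator("number_of_clicks", "click", ["count"],
                                 SlidingWindows(["1h", "2h", "24h"], "10m"))],
                table,
            ),
        ]).run()
        controller.emit({"click": 1}, key="user_1")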
"""

    def __init__(
self,
aggregates: Union[List[FieldAggregator], List[Dict[str, object]]],
table: Union[Table, str],
key_field: Union[str, List[str], Callable[[Event], object], None] = None,
time_field: Union[str, Callable[[Event], object], None] = None,
emit_policy: Union[EmitPolicy, Dict[str, object]] = _default_emit_policy,
augmentation_fn: Optional[Callable[[Event, Dict[str, object]], Event]] = None,
enrich_with: Optional[List[str]] = None,
aliases: Optional[Dict[str, str]] = None,
use_windows_from_schema: bool = False,
time_format: Optional[str] = None,
**kwargs,
):
Flow.__init__(self, **kwargs)
aggregates = self._parse_aggregates(aggregates)
self._check_unique_names(aggregates)
self._table = table
if isinstance(table, str):
if not self.context:
raise TypeError("Table can not be string if no context was provided to the step")
self._table = self.context.get_table(table)
self._table._set_aggregation_metadata(aggregates, use_windows_from_schema=use_windows_from_schema)
self._closeables = [self._table]
self._aggregates_metadata = aggregates
self._enrich_with = enrich_with or []
self._aliases = aliases or {}
        if not isinstance(emit_policy, (EmitPolicy, dict)):
            raise TypeError(
                f"emit_policy parameter must be of type EmitPolicy or dict. Found {type(emit_policy)} instead."
            )
self._emit_policy = emit_policy
if isinstance(self._emit_policy, dict):
self._emit_policy = _dict_to_emit_policy(self._emit_policy)
self._augmentation_fn = augmentation_fn
if not augmentation_fn:
def f(element, features):
features.update(element)
return features
self._augmentation_fn = f
self._key_extractor = None
if key_field:
if callable(key_field):
self._key_extractor = key_field
elif isinstance(key_field, str):
self._key_extractor = lambda event: event.get(key_field)
elif isinstance(key_field, list):
self._key_extractor = lambda event: [event.get(single_key) for single_key in key_field]
else:
            raise TypeError(
                f"key_field is expected to be either a callable, a string, or a list of strings, "
                f"but got {type(key_field)}"
            )
self._time_extractor = lambda event: event.processing_time
if time_field:
if callable(time_field):
self._time_extractor = time_field
elif isinstance(time_field, str):
self._time_extractor = lambda event: event.body.get(time_field)
else:
raise TypeError(f"time_field is expected to be either a callable or string but got {type(time_field)}")
self._time_format = time_format

    def _init(self):
super()._init()
self._events_in_batch = {}
self._emit_worker_running = False
self._terminate_worker = False
self._timeout_task: Optional[asyncio.Task] = None

    def _check_unique_names(self, aggregates):
unique_aggr_names = set()
for aggr in aggregates:
if aggr.name in unique_aggr_names:
raise TypeError(f"Aggregates should have unique names. {aggr.name} already exists")
unique_aggr_names.add(aggr.name)
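
    # The dictionary form accepted below mirrors FieldAggregator's arguments. An illustrative
    # entry: {"name": "clicks", "column": "click", "operations": ["count"],
    # "windows": ["1h", "24h"], "period": "10m"}. A "period" key selects SlidingWindows, and
    # its absence selects FixedWindows; "aggregation_filter" and "max_value" are optional.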
@staticmethod
def _parse_aggregates(aggregates):
if not isinstance(aggregates, list):
raise TypeError("aggregates should be a list of FieldAggregator/dictionaries")
if not aggregates or isinstance(aggregates[0], FieldAggregator):
return aggregates
if isinstance(aggregates[0], dict):
new_aggregates = []
for aggregate_dict in aggregates:
if "period" in aggregate_dict:
window = SlidingWindows(aggregate_dict["windows"], aggregate_dict["period"])
else:
window = FixedWindows(aggregate_dict["windows"])
new_aggregates.append(
FieldAggregator(
aggregate_dict["name"],
aggregate_dict["column"],
aggregate_dict["operations"],
window,
aggregate_dict.get("aggregation_filter", None),
aggregate_dict.get("max_value", None),
)
)
return new_aggregates
raise TypeError("aggregates should be a list of FieldAggregator/dictionaries")

    def _get_timestamp(self, event):
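        """Return the event's time as epoch milliseconds: string values are parsed
        (via time_format if set, otherwise ISO-8601), and datetime/pandas Timestamp
        values are converted with timestamp()."""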
event_timestamp = self._time_extractor(event)
if isinstance(event_timestamp, str):
if self._time_format:
event_timestamp = datetime.strptime(event_timestamp, self._time_format)
else:
event_timestamp = datetime.fromisoformat(event_timestamp)
if isinstance(event_timestamp, datetime):
if isinstance(event_timestamp, pd.Timestamp) and event_timestamp.tzinfo is None:
# timestamp for pandas timestamp gives the wrong result in case there is no timezone (ML-313)
local_time_zone = datetime.now().astimezone().tzinfo
event_timestamp = event_timestamp.replace(tzinfo=local_time_zone)
event_timestamp = event_timestamp.timestamp() * 1000
return event_timestamp

    async def _do(self, event):
        if event is _termination_obj:
            self._terminate_worker = True
            return await self._do_downstream(_termination_obj)
        # check whether a background emit loop is needed; if so, create and start one
        if not self._emit_worker_running and isinstance(self._emit_policy, (EmitAfterPeriod, EmitAfterWindow)):
            asyncio.get_running_loop().create_task(self._emit_worker())
            self._emit_worker_running = True
        element = event.body
        key = event.key
        if self._key_extractor:
            key = self._key_extractor(element)
        event_timestamp = self._get_timestamp(event)
        safe_key = stringify_key(key)
        await self._table._lazy_load_key_with_aggregates(safe_key, event_timestamp)
        await self._table._aggregate(safe_key, event, element, event_timestamp)
        if isinstance(self._emit_policy, EmitEveryEvent):
            await self._emit_event(key, event)
        elif isinstance(self._emit_policy, EmitAfterMaxEvent):
            if safe_key in self._events_in_batch:
                self._events_in_batch[safe_key]["counter"] += 1
            else:
                self._events_in_batch[safe_key] = {"counter": 1, "time": time.monotonic()}
            # always keep the latest event for the key, so that it is the one emitted
            self._events_in_batch[safe_key]["event"] = event
            if self._emit_policy.timeout_secs and self._timeout_task is None:
                self._timeout_task = asyncio.get_running_loop().create_task(self._sleep_and_emit())
            if self._events_in_batch[safe_key]["counter"] == self._emit_policy.max_events:
                event_from_batch = self._events_in_batch.pop(safe_key, None)
                if event_from_batch is not None:
                    await self._emit_event(key, event_from_batch["event"])

    async def _sleep_and_emit(self):
while self._events_in_batch:
key = next(iter(self._events_in_batch.keys()))
delta_seconds = time.monotonic() - self._events_in_batch[key]["time"]
if delta_seconds < self._emit_policy.timeout_secs:
await asyncio.sleep(self._emit_policy.timeout_secs - delta_seconds)
event = self._events_in_batch.pop(key, None)
if event is not None:
await self._emit_event(key, event["event"])
self._timeout_task = None

    # Emit a single event for the requested key
async def _emit_event(self, key, event):
event_timestamp = self._get_timestamp(event)
safe_key = stringify_key(key)
await self._table._lazy_load_key_with_aggregates(safe_key, event_timestamp)
features = await self._table._get_features(safe_key, event_timestamp)
for feature_name in list(features.keys()):
if feature_name in self._aliases:
new_feature_name = self._aliases[feature_name]
if feature_name != new_feature_name:
features[new_feature_name] = features[feature_name]
del features[feature_name]
features = self._augmentation_fn(event.body, features)
for col in self._enrich_with:
emitted_attr_name = self._aliases.get(col, None) or col
if col in self._table._get_static_attrs(safe_key):
features[emitted_attr_name] = self._table._get_static_attrs(safe_key)[col]
event.key = key
event.body = features
await self._do_downstream(event)

    # Emit an event for every key in the store, stamped with the given time
async def _emit_all_events(self, timestamp):
for key in self._table._get_keys():
await self._emit_event(key, Event({"key": key, "time": timestamp}, key, timestamp, None))

    async def _emit_worker(self):
if isinstance(self._emit_policy, EmitAfterPeriod):
seconds_to_sleep_between_emits = self._aggregates_metadata[0].windows.period_millis / 1000
elif isinstance(self._emit_policy, EmitAfterWindow):
seconds_to_sleep_between_emits = self._aggregates_metadata[0].windows.windows[0][0] / 1000
else:
raise TypeError(f'Emit policy "{type(self._emit_policy)}" is not supported')
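        # align the first emit to an absolute period/window boundary: e.g. with a
        # 10-minute period and current time 12:03, the first emit fires at 12:10
        # (plus any configured delay_in_seconds)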
current_time = datetime.now().timestamp()
next_emit_time = (
int(current_time / seconds_to_sleep_between_emits) * seconds_to_sleep_between_emits
+ seconds_to_sleep_between_emits
)
while not self._terminate_worker:
current_time = datetime.now().timestamp()
next_sleep_interval = next_emit_time - current_time + self._emit_policy.delay_in_seconds
if next_sleep_interval > 0:
await asyncio.sleep(next_sleep_interval)
await self._emit_all_events(next_emit_time * 1000)
next_emit_time = next_emit_time + seconds_to_sleep_between_emits


class QueryByKey(AggregateByKey):
"""
    Query features by name from an aggregation table.

    :param features: List of features to get. Aggregation features are named
        `<name>_<aggr>_<window>` (e.g. `number_of_clicks_count_1h`); any other feature is
        fetched as a plain column from the table.
:param table: A Table object or name for persistence of aggregations.
If a table name is provided, it will be looked up in the context object passed in kwargs.
    :param key_field: Key field to query by, accepts either a string representing the key field or a key
        extracting function. Defaults to the key in the event's metadata. Can also be a list of keys. (Optional)
:param time_field: Time field to query by, accepts either a string representing the time field or a time
extracting function. Defaults to the processing time in the event's metadata. (Optional)
:param augmentation_fn: Function that augments the features into the event's body.
Defaults to updating a dict. (Optional)
:param aliases: Dictionary specifying aliases for enriched or aggregate columns, of the
format `{'col_name': 'new_col_name'}`. (Optional)
    :param fixed_window_type: Whether fixed windows are queried as the current open window or as the
        last closed window. Defaults to FixedWindowType.CurrentOpenWindow. (Optional)
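
    Example (a minimal, illustrative sketch; assumes a `table` previously populated by an
    `AggregateByKey` step, with `build_flow` and `SyncEmitSource` from storey's public API)::

        flow = build_flow([
            SyncEmitSource(),
            QueryByKey(["number_of_clicks_count_1h"], table, key_field="user_id"),
        ])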
"""

    def __init__(
self,
features: List[str],
table: Union[Table, str],
key_field: Union[str, List[str], Callable[[Event], object], None] = None,
time_field: Union[str, List[str], Callable[[Event], object], None] = None,
augmentation_fn: Optional[Callable[[Event, Dict[str, object]], Event]] = None,
aliases: Optional[Dict[str, str]] = None,
fixed_window_type: Optional[FixedWindowType] = FixedWindowType.CurrentOpenWindow,
**kwargs,
):
self._aggrs = []
self._enrich_cols = []
resolved_aggrs = {}
if isinstance(table, str):
if "context" not in kwargs:
raise TypeError("Table can not be string if no context was provided to the step")
table = kwargs["context"].get_table(table)
for feature in features:
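            # aggregation features are named "<name>_<aggr>_<window>" (e.g. "clicks_count_1h");
            # any feature that doesn't match is treated as a plain column to enrich with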
match = re.match(r".*_([a-z]+)_[0-9]+[smhd]$", feature) if table.supports_aggregations() else None
if match and is_aggregation_name(match.group(1)):
name, window = feature.rsplit("_", 1)
if name in resolved_aggrs:
resolved_aggrs[name].append(window)
else:
resolved_aggrs[name] = [window]
else:
self._enrich_cols.append(feature)
for name, windows in resolved_aggrs.items():
feature, aggr = name.rsplit("_", 1)
            # set as SlidingWindows temporarily; the actual window type will be read from the schema
self._aggrs.append(
FieldAggregator(
name=feature,
field=None,
aggr=[aggr],
windows=SlidingWindows(windows),
)
)
        # if the table already carries aggregation metadata, clone it so that this query's
        # metadata does not overwrite that of an existing AggregateByKey step
        other_table = table._clone() if table._aggregates is not None else table
AggregateByKey.__init__(
self,
self._aggrs,
other_table,
key_field,
time_field,
augmentation_fn=augmentation_fn,
enrich_with=self._enrich_cols,
aliases=aliases,
use_windows_from_schema=True,
**kwargs,
)
self._table._aggregations_read_only = True
self._table.fixed_window_type = fixed_window_type

    async def _do(self, event):
        if event is _termination_obj:
            self._terminate_worker = True
            return await self._do_downstream(_termination_obj)
        element = event.body
        key = event.key
        if self._key_extractor and element:
            key = self._key_extractor(element)
        if key is None or key == [None] or element is None:
            # nothing to query by - pass the event through with an empty body
            event.body = None
            await self._do_downstream(event)
            return
await self._emit_event(key, event)

    def _check_unique_names(self, aggregates):
        # Unlike AggregateByKey, queried aggregators may legitimately share a name (e.g. both
        # "clicks_count_1h" and "clicks_avg_1h" resolve to an aggregator named "clicks"),
        # so the uniqueness check is intentionally skipped.
        pass