# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ast
import base64
import json
import typing
import warnings
from urllib.parse import ParseResult, urlparse, urlunparse

import pydantic
from mergedeep import merge

import mlrun
import mlrun.errors

from ..secrets import get_secret_or_env


class DatastoreProfile(pydantic.BaseModel):
    type: str
    name: str
    _private_attributes: list = ()

    class Config:
        extra = pydantic.Extra.forbid

    @pydantic.validator("name")
    @classmethod
    def lower_case(cls, v):
        return v.lower()

    @staticmethod
    def generate_secret_key(profile_name: str, project: str):
        secret_name_separator = "."
        full_key = (
            "datastore-profiles"
            + secret_name_separator
            + project
            + secret_name_separator
            + profile_name
        )
        return full_key

    def secrets(self) -> dict:
        return None

    def url(self, subpath) -> str:
        return None
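

# Illustrative sketch (not part of the module API): a profile's private attributes are
# stored under a project-secret key of the form "datastore-profiles.<project>.<profile>",
# as produced by generate_secret_key(). The names below are placeholders.
def _example_generate_secret_key() -> str:
    # Returns "datastore-profiles.my-project.my-profile"
    return DatastoreProfile.generate_secret_key("my-profile", "my-project")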


class TemporaryClientDatastoreProfiles(metaclass=mlrun.utils.singleton.Singleton):
    def __init__(self):
        self._data = {}  # mapping of profile name -> DatastoreProfile

    def add(self, profile: DatastoreProfile):
        self._data[profile.name] = profile

    def get(self, key):
        return self._data.get(key, None)

    def remove(self, key):
        self._data.pop(key, None)


class DatastoreProfileBasic(DatastoreProfile):
    type: str = pydantic.Field("basic")
    _private_attributes = "private"
    public: str
    private: typing.Optional[str] = None


class DatastoreProfileKafkaTarget(DatastoreProfile):
    type: str = pydantic.Field("kafka_target")
    _private_attributes = "kwargs_private"
    bootstrap_servers: typing.Optional[str] = None
    brokers: typing.Optional[str] = None
    topic: str
    kwargs_public: typing.Optional[dict]
    kwargs_private: typing.Optional[dict]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        if not self.brokers and not self.bootstrap_servers:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "DatastoreProfileKafkaTarget requires the 'brokers' field to be set"
            )
        if self.bootstrap_servers:
            if self.brokers:
                raise mlrun.errors.MLRunInvalidArgumentError(
                    "DatastoreProfileKafkaTarget cannot be created with both 'brokers' and 'bootstrap_servers'"
                )
            else:
                self.brokers = self.bootstrap_servers
                self.bootstrap_servers = None
                warnings.warn(
                    "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
                    "use 'brokers' instead.",
                    # TODO: Remove this in 1.9.0
                    FutureWarning,
                )

    def attributes(self):
        attributes = {"brokers": self.brokers or self.bootstrap_servers}
        if self.kwargs_public:
            attributes = merge(attributes, self.kwargs_public)
        if self.kwargs_private:
            attributes = merge(attributes, self.kwargs_private)
        return attributes


class DatastoreProfileKafkaSource(DatastoreProfile):
    type: str = pydantic.Field("kafka_source")
    _private_attributes = ("kwargs_private", "sasl_user", "sasl_pass")
    brokers: typing.Union[str, list[str]]
    topics: typing.Union[str, list[str]]
    group: typing.Optional[str] = "serving"
    initial_offset: typing.Optional[str] = "earliest"
    partitions: typing.Optional[typing.Union[str, list[str]]]
    sasl_user: typing.Optional[str]
    sasl_pass: typing.Optional[str]
    kwargs_public: typing.Optional[dict]
    kwargs_private: typing.Optional[dict]

    def attributes(self):
        attributes = {}
        if self.kwargs_public:
            attributes = merge(attributes, self.kwargs_public)
        if self.kwargs_private:
            attributes = merge(attributes, self.kwargs_private)
        topics = [self.topics] if isinstance(self.topics, str) else self.topics
        brokers = [self.brokers] if isinstance(self.brokers, str) else self.brokers
        attributes["brokers"] = brokers
        attributes["topics"] = topics
        attributes["group"] = self.group
        attributes["initial_offset"] = self.initial_offset
        if self.partitions is not None:
            attributes["partitions"] = self.partitions
        sasl = attributes.pop("sasl", {})
        if self.sasl_user and self.sasl_pass:
            sasl["enabled"] = True
            sasl["user"] = self.sasl_user
            sasl["password"] = self.sasl_pass
        if sasl:
            attributes["sasl"] = sasl
        return attributes
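

# Illustrative sketch (not part of the module API): shows the attributes dict a Kafka
# source profile produces, including the "sasl" section added when both credentials are
# set. The broker, topic, and credential values below are placeholders.
def _example_kafka_source_attributes() -> dict:
    profile = DatastoreProfileKafkaSource(
        name="my-kafka",
        brokers="broker1:9092",
        topics="orders",
        sasl_user="user",
        sasl_pass="secret",
    )
    # Expected shape:
    # {"brokers": ["broker1:9092"], "topics": ["orders"], "group": "serving",
    #  "initial_offset": "earliest",
    #  "sasl": {"enabled": True, "user": "user", "password": "secret"}}
    return profile.attributes()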


class DatastoreProfileV3io(DatastoreProfile):
    type: str = pydantic.Field("v3io")
    v3io_access_key: typing.Optional[str] = None
    _private_attributes = "v3io_access_key"

    def url(self, subpath):
        subpath = subpath.lstrip("/")
        return f"v3io:///{subpath}"

    def secrets(self) -> dict:
        res = {}
        if self.v3io_access_key:
            res["V3IO_ACCESS_KEY"] = self.v3io_access_key
        return res


class DatastoreProfileS3(DatastoreProfile):
    type: str = pydantic.Field("s3")
    _private_attributes = ("access_key_id", "secret_key")
    endpoint_url: typing.Optional[str] = None
    force_non_anonymous: typing.Optional[str] = None
    profile_name: typing.Optional[str] = None
    assume_role_arn: typing.Optional[str] = None
    access_key_id: typing.Optional[str] = None
    secret_key: typing.Optional[str] = None
    bucket: typing.Optional[str] = None

    @pydantic.validator("bucket")
    @classmethod
    def check_bucket(cls, v):
        if not v:
            warnings.warn(
                "The 'bucket' attribute will be mandatory starting from version 1.9",
                FutureWarning,
                stacklevel=2,
            )
        return v

    def secrets(self) -> dict:
        res = {}
        if self.access_key_id:
            res["AWS_ACCESS_KEY_ID"] = self.access_key_id
        if self.secret_key:
            res["AWS_SECRET_ACCESS_KEY"] = self.secret_key
        if self.endpoint_url:
            res["S3_ENDPOINT_URL"] = self.endpoint_url
        if self.force_non_anonymous:
            res["S3_NON_ANONYMOUS"] = self.force_non_anonymous
        if self.profile_name:
            res["AWS_PROFILE"] = self.profile_name
        if self.assume_role_arn:
            res["MLRUN_AWS_ROLE_ARN"] = self.assume_role_arn
        return res

    def url(self, subpath):
        # TODO: This is inconsistent with DatastoreProfileGCS: there the subpath may
        # begin without a leading '/', while here it is assumed to always start with one.
        if self.bucket:
            return f"s3://{self.bucket}{subpath}"
        else:
            return f"s3:/{subpath}"


class DatastoreProfileRedis(DatastoreProfile):
    type: str = pydantic.Field("redis")
    _private_attributes = ("username", "password")
    endpoint_url: str
    username: typing.Optional[str] = None
    password: typing.Optional[str] = None

    def url_with_credentials(self):
        parsed_url = urlparse(self.endpoint_url)
        username = self.username
        password = self.password
        netloc = parsed_url.hostname
        if username:
            if password:
                netloc = f"{username}:{password}@{parsed_url.hostname}"
            else:
                netloc = f"{username}@{parsed_url.hostname}"
        if parsed_url.port:
            netloc += f":{parsed_url.port}"
        new_parsed_url = ParseResult(
            scheme=parsed_url.scheme,
            netloc=netloc,
            path=parsed_url.path,
            params=parsed_url.params,
            query=parsed_url.query,
            fragment=parsed_url.fragment,
        )
        return urlunparse(new_parsed_url)

    def secrets(self) -> dict:
        res = {}
        if self.username:
            res["REDIS_USER"] = self.username
        if self.password:
            res["REDIS_PASSWORD"] = self.password
        return res

    def url(self, subpath):
        return self.endpoint_url + subpath
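

# Illustrative sketch (not part of the module API): url_with_credentials() re-embeds the
# stored username and password into the endpoint URL. The host and credentials below are
# placeholders.
def _example_redis_url_with_credentials() -> str:
    profile = DatastoreProfileRedis(
        name="my-redis",
        endpoint_url="redis://my-redis-host:6379",
        username="user",
        password="secret",
    )
    # Returns "redis://user:secret@my-redis-host:6379"
    return profile.url_with_credentials()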


class DatastoreProfileDBFS(DatastoreProfile):
    type: str = pydantic.Field("dbfs")
    _private_attributes = ("token",)
    endpoint_url: typing.Optional[str] = None  # host
    token: typing.Optional[str] = None

    def url(self, subpath) -> str:
        return f"dbfs://{subpath}"

    def secrets(self) -> dict:
        res = {}
        if self.token:
            res["DATABRICKS_TOKEN"] = self.token
        if self.endpoint_url:
            res["DATABRICKS_HOST"] = self.endpoint_url
        return res


class DatastoreProfileGCS(DatastoreProfile):
    type: str = pydantic.Field("gcs")
    _private_attributes = ("gcp_credentials",)
    credentials_path: typing.Optional[str] = None  # path to file.
    gcp_credentials: typing.Optional[typing.Union[str, dict]] = None
    bucket: typing.Optional[str] = None

    @pydantic.validator("bucket")
    @classmethod
    def check_bucket(cls, v):
        if not v:
            warnings.warn(
                "The 'bucket' attribute will be mandatory starting from version 1.9",
                FutureWarning,
                stacklevel=2,
            )
        return v

    @pydantic.validator("gcp_credentials", pre=True, always=True)
    @classmethod
    def convert_dict_to_json(cls, v):
        if isinstance(v, dict):
            return json.dumps(v)
        return v

    def url(self, subpath) -> str:
        # TODO: The subpath is assumed here not to start with a slash, while S3 makes the
        # opposite assumption.
        if subpath.startswith("/"):
            # In GCS the path after the scheme starts with the bucket, so it should not
            # begin with "/".
            subpath = subpath[1:]
        if self.bucket:
            return f"gcs://{self.bucket}/{subpath}"
        else:
            return f"gcs://{subpath}"

    def secrets(self) -> dict:
        res = {}
        if self.credentials_path:
            res["GOOGLE_APPLICATION_CREDENTIALS"] = self.credentials_path
        if self.gcp_credentials:
            res["GCP_CREDENTIALS"] = self.gcp_credentials
        return res


class DatastoreProfileAzureBlob(DatastoreProfile):
    type: str = pydantic.Field("az")
    _private_attributes = (
        "connection_string",
        "account_key",
        "client_secret",
        "sas_token",
        "credential",
    )
    connection_string: typing.Optional[str] = None
    account_name: typing.Optional[str] = None
    account_key: typing.Optional[str] = None
    tenant_id: typing.Optional[str] = None
    client_id: typing.Optional[str] = None
    client_secret: typing.Optional[str] = None
    sas_token: typing.Optional[str] = None
    credential: typing.Optional[str] = None
    container: typing.Optional[str] = None

    @pydantic.validator("container")
    @classmethod
    def check_container(cls, v):
        if not v:
            warnings.warn(
                "The 'container' attribute will be mandatory starting from version 1.9",
                FutureWarning,
                stacklevel=2,
            )
        return v

    def url(self, subpath) -> str:
        if subpath.startswith("/"):
            # In Azure the path after the scheme starts with the container, so it should
            # not begin with "/".
            subpath = subpath[1:]
        if self.container:
            return f"az://{self.container}/{subpath}"
        else:
            return f"az://{subpath}"

    def secrets(self) -> dict:
        res = {}
        if self.connection_string:
            res["connection_string"] = self.connection_string
        if self.account_name:
            res["account_name"] = self.account_name
        if self.account_key:
            res["account_key"] = self.account_key
        if self.tenant_id:
            res["tenant_id"] = self.tenant_id
        if self.client_id:
            res["client_id"] = self.client_id
        if self.client_secret:
            res["client_secret"] = self.client_secret
        if self.sas_token:
            res["sas_token"] = self.sas_token
        if self.credential:
            res["credential"] = self.credential
        return res


class DatastoreProfileHdfs(DatastoreProfile):
    type: str = pydantic.Field("hdfs")
    _private_attributes = "token"
    host: typing.Optional[str] = None
    port: typing.Optional[int] = None
    http_port: typing.Optional[int] = None
    user: typing.Optional[str] = None

    def secrets(self) -> dict:
        res = {}
        if self.host:
            res["HDFS_HOST"] = self.host
        if self.port:
            res["HDFS_PORT"] = self.port
        if self.http_port:
            res["HDFS_HTTP_PORT"] = self.http_port
        if self.user:
            res["HDFS_USER"] = self.user
        return res or None

    def url(self, subpath):
        return f"webhdfs://{self.host}:{self.http_port}{subpath}"


class DatastoreProfile2Json(pydantic.BaseModel):
    @staticmethod
    def _to_json(attributes):
        # First, base64 encode the values
        encoded_dict = {
            k: base64.b64encode(str(v).encode()).decode() for k, v in attributes.items()
        }
        # Then, return the dictionary as a JSON string with no spaces
        return json.dumps(encoded_dict).replace(" ", "")

    @staticmethod
def get_json_public(profile: DatastoreProfile) -> str:
return DatastoreProfile2Json._to_json(
{
k: v
for k, v in profile.dict().items()
if str(k) not in profile._private_attributes
}
)

    @staticmethod
def get_json_private(profile: DatastoreProfile) -> str:
return DatastoreProfile2Json._to_json(
{
k: v
for k, v in profile.dict().items()
if str(k) in profile._private_attributes
}
)

    @staticmethod
    def create_from_json(public_json: str, private_json: str = "{}"):
        attributes = json.loads(public_json)
        attributes_public = {
            k: base64.b64decode(str(v).encode()).decode() for k, v in attributes.items()
        }
        attributes = json.loads(private_json)
        attributes_private = {
            k: base64.b64decode(str(v).encode()).decode() for k, v in attributes.items()
        }
        decoded_dict = merge(attributes_public, attributes_private)

        def safe_literal_eval(value):
            try:
                return ast.literal_eval(value)
            except (ValueError, SyntaxError):
                return value

        decoded_dict = {k: safe_literal_eval(v) for k, v in decoded_dict.items()}
datastore_type = decoded_dict.get("type")
ds_profile_factory = {
"v3io": DatastoreProfileV3io,
"s3": DatastoreProfileS3,
"redis": DatastoreProfileRedis,
"basic": DatastoreProfileBasic,
"kafka_target": DatastoreProfileKafkaTarget,
"kafka_source": DatastoreProfileKafkaSource,
"dbfs": DatastoreProfileDBFS,
"gcs": DatastoreProfileGCS,
"az": DatastoreProfileAzureBlob,
"hdfs": DatastoreProfileHdfs,
}
if datastore_type in ds_profile_factory:
return ds_profile_factory[datastore_type].parse_obj(decoded_dict)
else:
if datastore_type:
reason = f"unexpected type '{decoded_dict['type']}'"
else:
reason = "missing type"
raise mlrun.errors.MLRunInvalidArgumentError(
f"Datastore profile failed to create from json due to {reason}"
)
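

# Illustrative sketch (not part of the module API): a profile can be split into its public
# and private JSON representations and reassembled with create_from_json(). The profile
# name and credentials below are placeholders.
def _example_profile_json_round_trip() -> DatastoreProfile:
    profile = DatastoreProfileRedis(
        name="my-redis",
        endpoint_url="redis://my-redis-host:6379",
        username="user",
        password="secret",
    )
    public_json = DatastoreProfile2Json.get_json_public(profile)
    private_json = DatastoreProfile2Json.get_json_private(profile)
    # Reassembles an equivalent DatastoreProfileRedis instance
    return DatastoreProfile2Json.create_from_json(public_json, private_json)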


def datastore_profile_read(url, project_name="", secrets: dict = None):
parsed_url = urlparse(url)
if parsed_url.scheme.lower() != "ds":
raise mlrun.errors.MLRunInvalidArgumentError(
f"resource URL '{url}' cannot be read as a datastore profile because its scheme is not 'ds'"
)
profile_name = parsed_url.hostname
project_name = project_name or mlrun.mlconf.default_project
datastore = TemporaryClientDatastoreProfiles().get(profile_name)
if datastore:
return datastore
public_profile = mlrun.db.get_run_db().get_datastore_profile(
profile_name, project_name
)
# The mlrun.db.get_run_db().get_datastore_profile() function is capable of returning
# two distinct types of objects based on its execution context.
# If it operates from the client or within the pod (which is the common scenario),
# it yields an instance of `mlrun.datastore.DatastoreProfile`. Conversely,
# when executed on the server with a direct call to `sqldb`, it produces an instance of
# mlrun.common.schemas.DatastoreProfile.
# In the latter scenario, an extra conversion step is required to transform the object
# into mlrun.datastore.DatastoreProfile.
if isinstance(public_profile, mlrun.common.schemas.DatastoreProfile):
public_profile = DatastoreProfile2Json.create_from_json(
public_json=public_profile.object
)
project_ds_name_private = DatastoreProfile.generate_secret_key(
profile_name, project_name
)
private_body = get_secret_or_env(project_ds_name_private, secret_provider=secrets)
if not public_profile or not private_body:
raise mlrun.errors.MLRunInvalidArgumentError(
f"Unable to retrieve the datastore profile '{url}' from either the server or local environment. "
"Make sure the profile is registered correctly, or if running in a local environment, "
"use register_temporary_client_datastore_profile() to provide credentials locally."
)
datastore = DatastoreProfile2Json.create_from_json(
public_json=DatastoreProfile2Json.get_json_public(public_profile),
private_json=private_body,
)
return datastore
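

# Illustrative sketch (not part of the module API): resolving a "ds://" URL into a profile
# object. The profile must already be registered (on the server or temporarily on the
# client); the profile and project names below are placeholders.
def _example_datastore_profile_read() -> DatastoreProfile:
    return datastore_profile_read("ds://my-profile", project_name="my-project")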


def register_temporary_client_datastore_profile(profile: DatastoreProfile):
    """Register a datastore profile for the current client session only.

    The profile is temporary and remains valid only for the duration of the caller's
    session; it is useful for testing and local development.
    """
TemporaryClientDatastoreProfiles().add(profile)
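

# Illustrative sketch (not part of the module API): registering a profile locally so that
# "ds://my-s3" URLs resolve in this session. The bucket and credentials are placeholders.
def _example_register_temporary_profile() -> None:
    profile = DatastoreProfileS3(
        name="my-s3",
        bucket="my-bucket",
        access_key_id="<access-key-id>",
        secret_key="<secret-key>",
    )
    register_temporary_client_datastore_profile(profile)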


def remove_temporary_client_datastore_profile(profile_name: str):
    """Remove a previously registered temporary client datastore profile."""
    TemporaryClientDatastoreProfiles().remove(profile_name)