# Copyright 2023 Iguazio
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import base64
import json
import typing
import warnings
from urllib.parse import ParseResult, quote, urlparse, urlunparse

import pydantic
from mergedeep import merge

import mlrun
import mlrun.errors

from ..secrets import get_secret_or_env
class DatastoreProfile(pydantic.BaseModel):
    """Base model for a named datastore profile.

    Subclasses set ``type`` and list in ``_private_attributes`` the names of the
    fields that must be kept out of the public JSON representation.
    """

    type: str
    name: str
    # Names of secret fields. Subclasses should use a tuple: with a bare string,
    # ``key in _private_attributes`` silently becomes a substring test.
    _private_attributes: tuple = ()

    class Config:
        # Reject unknown fields instead of silently ignoring them.
        extra = pydantic.Extra.forbid

    @pydantic.validator("name")
    def lower_case(cls, v):
        """Normalize the profile name to lower case."""
        return v.lower()

    @staticmethod
    def generate_secret_key(profile_name: str, project: str):
        """Return the secret key under which a profile's private part is stored.

        The key has the form ``datastore-profiles.<project>.<profile_name>``.
        """
        secret_name_separator = "."
        # NOTE(review): confirm the "datastore-profiles" prefix matches the key
        # the server uses when it stores profile secrets.
        return secret_name_separator.join(["datastore-profiles", project, profile_name])

    def secrets(self) -> typing.Optional[dict]:
        """Return the profile's secrets; the base profile has none."""
        return None

    def url(self, subpath) -> typing.Optional[str]:
        """Return the datastore URL for ``subpath``; the base profile has none."""
        return None
class TemporaryClientDatastoreProfiles(metaclass=mlrun.utils.singleton.Singleton):
    """In-process registry of client-side datastore profiles, keyed by name.

    Built with mlrun's ``Singleton`` metaclass (presumably one shared instance
    per process).
    """

    def __init__(self):
        # profile name -> DatastoreProfile
        self._data = {}

    def add(self, profile: DatastoreProfile):
        """Store ``profile`` under its name, replacing any previous entry."""
        profile_key = profile.name
        self._data[profile_key] = profile

    def get(self, key):
        """Return the profile registered as ``key``, or ``None`` when absent."""
        return self._data.get(key)

    def remove(self, key):
        """Forget the profile registered as ``key``; unknown keys are ignored."""
        if key in self._data:
            del self._data[key]
class DatastoreProfileBasic(DatastoreProfile):
    """Generic profile holding one public value and one optional private value."""

    type: str = pydantic.Field("basic")
    # A tuple, so membership checks are exact field-name matches rather than
    # substring matches against the string "private".
    _private_attributes = ("private",)
    public: str
    private: typing.Optional[str] = None
class DatastoreProfileKafkaTarget(DatastoreProfile):
    """Profile describing a Kafka target (brokers and topic)."""

    type: str = pydantic.Field("kafka_target")
    # Tuple for exact-match membership checks (a bare string would do substring tests).
    _private_attributes = ("kwargs_private",)
    # Deprecated alias of ``brokers``; ``__init__`` moves its value into ``brokers``.
    bootstrap_servers: typing.Optional[str] = None
    brokers: typing.Optional[str] = None
    topic: str
    kwargs_public: typing.Optional[dict]
    kwargs_private: typing.Optional[dict]

    def __init__(self, **kwargs):
        """Validate the fields and migrate the deprecated ``bootstrap_servers``.

        :raises MLRunInvalidArgumentError: if neither ``brokers`` nor
            ``bootstrap_servers`` is set, or if both are set.
        """
        # Let pydantic populate and validate the fields before they are read below.
        super().__init__(**kwargs)
        if not self.brokers and not self.bootstrap_servers:
            raise mlrun.errors.MLRunInvalidArgumentError(
                "DatastoreProfileKafkaTarget requires the 'brokers' field to be set"
            )
        if self.bootstrap_servers:
            if self.brokers:
                raise mlrun.errors.MLRunInvalidArgumentError(
                    "DatastoreProfileKafkaTarget cannot be created with both 'brokers' and 'bootstrap_servers'"
                )
            self.brokers = self.bootstrap_servers
            self.bootstrap_servers = None
            # TODO: Remove this in 1.9.0
            warnings.warn(
                "'bootstrap_servers' parameter is deprecated in 1.7.0 and will be removed in 1.9.0, "
                "use 'brokers' instead.",
                FutureWarning,
            )

    def attributes(self):
        """Return the target attributes.

        Starts from ``{"brokers": ...}`` and merges ``kwargs_public`` and then
        ``kwargs_private`` into it with ``mergedeep.merge``.
        """
        attributes = {"brokers": self.brokers or self.bootstrap_servers}
        if self.kwargs_public:
            attributes = merge(attributes, self.kwargs_public)
        if self.kwargs_private:
            attributes = merge(attributes, self.kwargs_private)
        return attributes
class DatastoreProfileKafkaSource(DatastoreProfile):
    """Profile describing a Kafka source (brokers, topics, consumer group, SASL)."""

    type: str = pydantic.Field("kafka_source")
    _private_attributes = ("kwargs_private", "sasl_user", "sasl_pass")
    brokers: typing.Union[str, list[str]]
    topics: typing.Union[str, list[str]]
    group: typing.Optional[str] = "serving"
    initial_offset: typing.Optional[str] = "earliest"
    partitions: typing.Optional[typing.Union[str, list[str]]]
    sasl_user: typing.Optional[str]
    sasl_pass: typing.Optional[str]
    kwargs_public: typing.Optional[dict]
    kwargs_private: typing.Optional[dict]

    def attributes(self):
        """Build the attribute dict describing this Kafka source.

        ``kwargs_public`` and then ``kwargs_private`` are merged in first; the
        typed fields are written on top. Single-string ``brokers``/``topics``
        are wrapped in a list. A ``"sasl"`` dict coming from the kwargs is kept.
        When both ``sasl_user`` and ``sasl_pass`` are set, it also receives
        ``enabled``/``user``/``password``.
        """
        result = {}
        for extra_kwargs in (self.kwargs_public, self.kwargs_private):
            if extra_kwargs:
                result = merge(result, extra_kwargs)

        def _as_list(value):
            return [value] if isinstance(value, str) else value

        result["brokers"] = _as_list(self.brokers)
        result["topics"] = _as_list(self.topics)
        result["group"] = self.group
        result["initial_offset"] = self.initial_offset
        if self.partitions is not None:
            result["partitions"] = self.partitions

        sasl_settings = result.pop("sasl", {})
        if self.sasl_user and self.sasl_pass:
            sasl_settings["enabled"] = True
            sasl_settings["user"] = self.sasl_user
            sasl_settings["password"] = self.sasl_pass
        if sasl_settings:
            result["sasl"] = sasl_settings
        return result
class DatastoreProfileV3io(DatastoreProfile):
    """Profile for the v3io datastore, optionally carrying an access key."""

    type: str = pydantic.Field("v3io")
    v3io_access_key: typing.Optional[str] = None
    # Tuple for exact-match membership checks (a bare string would do substring tests).
    _private_attributes = ("v3io_access_key",)

    def url(self, subpath):
        """Return ``v3io:///<subpath>``, dropping any leading slashes from ``subpath``."""
        subpath = subpath.lstrip("/")
        return f"v3io:///{subpath}"

    def secrets(self) -> dict:
        """Return ``{"V3IO_ACCESS_KEY": ...}`` when an access key is set, else ``{}``."""
        res = {}
        if self.v3io_access_key:
            res["V3IO_ACCESS_KEY"] = self.v3io_access_key
        return res
class DatastoreProfileS3(DatastoreProfile):
    """Profile for S3 (or S3-compatible) object storage."""

    type: str = pydantic.Field("s3")
    _private_attributes = ("access_key_id", "secret_key")
    endpoint_url: typing.Optional[str] = None
    force_non_anonymous: typing.Optional[str] = None
    profile_name: typing.Optional[str] = None
    assume_role_arn: typing.Optional[str] = None
    access_key_id: typing.Optional[str] = None
    secret_key: typing.Optional[str] = None
    bucket: typing.Optional[str] = None

    @pydantic.validator("bucket")
    def check_bucket(cls, v):
        """Emit a FutureWarning when ``bucket`` is given an empty value."""
        # NOTE(review): without ``always=True`` pydantic v1 does not run this
        # validator when ``bucket`` is omitted entirely; confirm whether the
        # deprecation warning should cover that case as well.
        if not v:
            warnings.warn(
                "The 'bucket' attribute will be mandatory starting from version 1.9",
                FutureWarning,
            )
        return v

    def secrets(self) -> dict:
        """Map each set credential/config field to its environment variable name."""
        env_values = {
            "AWS_ACCESS_KEY_ID": self.access_key_id,
            "AWS_SECRET_ACCESS_KEY": self.secret_key,
            "S3_ENDPOINT_URL": self.endpoint_url,
            "S3_NON_ANONYMOUS": self.force_non_anonymous,
            "AWS_PROFILE": self.profile_name,
            "MLRUN_AWS_ROLE_ARN": self.assume_role_arn,
        }
        return {env: value for env, value in env_values.items() if value}

    def url(self, subpath):
        """Return the ``s3://`` URL for ``subpath``.

        ``subpath`` is normally given with a leading ``/``. One is prepended
        when it is non-empty and lacks it, so the bucket and key stay separated.
        Without a bucket, ``subpath`` itself supplies it
        (``s3:/`` + ``/bucket/key``).
        """
        # TODO: There is an inconsistency with DatastoreProfileGCS, which strips
        # a leading '/' instead of requiring one.
        if subpath and not subpath.startswith("/"):
            subpath = "/" + subpath
        if self.bucket:
            return f"s3://{self.bucket}{subpath}"
        return f"s3:/{subpath}"
class DatastoreProfileRedis(DatastoreProfile):
    """Profile for a Redis endpoint with optional username/password."""

    type: str = pydantic.Field("redis")
    _private_attributes = ("username", "password")
    endpoint_url: str
    username: typing.Optional[str] = None
    password: typing.Optional[str] = None

    def url_with_credentials(self):
        """Return ``endpoint_url`` with this profile's credentials in the netloc.

        Any userinfo already present in ``endpoint_url`` is discarded. The
        username and password are percent-encoded, so characters such as
        ``@``, ``:`` or ``/`` in them cannot corrupt the URL.
        """
        parsed_url = urlparse(self.endpoint_url)
        host = parsed_url.hostname
        # ``hostname`` strips the brackets around IPv6 literals; put them back so
        # the address is not confused with the port separator.
        if host and ":" in host:
            host = f"[{host}]"
        netloc = host
        if self.username:
            userinfo = quote(self.username, safe="")
            if self.password:
                userinfo += ":" + quote(self.password, safe="")
            netloc = f"{userinfo}@{host}"
        if parsed_url.port:
            netloc += f":{parsed_url.port}"
        new_parsed_url = ParseResult(
            scheme=parsed_url.scheme,
            netloc=netloc,
            path=parsed_url.path,
            params=parsed_url.params,
            query=parsed_url.query,
            fragment=parsed_url.fragment,
        )
        return urlunparse(new_parsed_url)

    def secrets(self) -> dict:
        """Return the set credentials as ``REDIS_USER`` / ``REDIS_PASSWORD``."""
        res = {}
        if self.username:
            res["REDIS_USER"] = self.username
        if self.password:
            res["REDIS_PASSWORD"] = self.password
        return res

    def url(self, subpath):
        """Return ``endpoint_url`` with ``subpath`` appended verbatim."""
        return self.endpoint_url + subpath
class DatastoreProfileDBFS(DatastoreProfile):
    """Profile for Databricks File System (DBFS) access."""

    type: str = pydantic.Field("dbfs")
    _private_attributes = ("token",)
    endpoint_url: typing.Optional[str] = None  # host
    token: typing.Optional[str] = None

    def url(self, subpath) -> str:
        """Return ``dbfs://`` followed by ``subpath`` as given."""
        return "dbfs://{}".format(subpath)

    def secrets(self) -> dict:
        """Expose the token and host under the Databricks environment variable names."""
        candidates = (
            ("DATABRICKS_TOKEN", self.token),
            ("DATABRICKS_HOST", self.endpoint_url),
        )
        return {env_name: value for env_name, value in candidates if value}
class DatastoreProfileGCS(DatastoreProfile):
    """Profile for Google Cloud Storage."""

    type: str = pydantic.Field("gcs")
    _private_attributes = ("gcp_credentials",)
    credentials_path: typing.Optional[str] = None  # path to file.
    gcp_credentials: typing.Optional[typing.Union[str, dict]] = None
    bucket: typing.Optional[str] = None

    @pydantic.validator("bucket")
    def check_bucket(cls, v):
        """Emit a FutureWarning when ``bucket`` is given an empty value."""
        if not v:
            warnings.warn(
                "The 'bucket' attribute will be mandatory starting from version 1.9",
                FutureWarning,
            )
        return v

    @pydantic.validator("gcp_credentials", pre=True, always=True)
    def convert_dict_to_json(cls, v):
        """Store credentials given as a dict as a JSON string."""
        if isinstance(v, dict):
            return json.dumps(v)
        return v

    def url(self, subpath) -> str:
        """Return the ``gcs://`` URL for ``subpath`` (one leading ``/`` is dropped)."""
        # TODO: the subpath is assumed here to not need a leading slash, while
        # DatastoreProfileS3 makes the opposite assumption.
        if subpath.startswith("/"):
            # In GCS the path after the scheme starts with the bucket, so it
            # must not start with "/".
            subpath = subpath[1:]
        if self.bucket:
            return f"gcs://{self.bucket}/{subpath}"
        return f"gcs://{subpath}"

    def secrets(self) -> dict:
        """Return the set credential settings keyed by their environment variable names."""
        res = {}
        if self.credentials_path:
            res["GOOGLE_APPLICATION_CREDENTIALS"] = self.credentials_path
        if self.gcp_credentials:
            res["GCP_CREDENTIALS"] = self.gcp_credentials
        return res
class DatastoreProfileAzureBlob(DatastoreProfile):
    """Profile for Azure Blob Storage."""

    type: str = pydantic.Field("az")
    # NOTE(review): confirm this list covers every credential field that must
    # never appear in the public JSON of the profile.
    _private_attributes = (
        "connection_string",
        "account_key",
        "client_secret",
        "sas_token",
        "credential",
    )
    connection_string: typing.Optional[str] = None
    account_name: typing.Optional[str] = None
    account_key: typing.Optional[str] = None
    tenant_id: typing.Optional[str] = None
    client_id: typing.Optional[str] = None
    client_secret: typing.Optional[str] = None
    sas_token: typing.Optional[str] = None
    credential: typing.Optional[str] = None
    container: typing.Optional[str] = None

    @pydantic.validator("container")
    def check_container(cls, v):
        """Emit a FutureWarning when ``container`` is given an empty value."""
        if not v:
            warnings.warn(
                "The 'container' attribute will be mandatory starting from version 1.9",
                FutureWarning,
            )
        return v

    def url(self, subpath) -> str:
        """Return the ``az://`` URL for ``subpath`` (one leading ``/`` is dropped)."""
        if subpath.startswith("/"):
            # In Azure the path after the scheme starts with the container, so
            # it must not start with "/".
            subpath = subpath[1:]
        if self.container:
            return f"az://{self.container}/{subpath}"
        return f"az://{subpath}"

    def secrets(self) -> dict:
        """Return the set Azure settings keyed by their field names."""
        candidates = {
            "connection_string": self.connection_string,
            "account_name": self.account_name,
            "account_key": self.account_key,
            "tenant_id": self.tenant_id,
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "sas_token": self.sas_token,
            "credential": self.credential,
        }
        return {key: value for key, value in candidates.items() if value}
class DatastoreProfileHdfs(DatastoreProfile):
    """Profile for HDFS accessed over WebHDFS."""

    type: str = pydantic.Field("hdfs")
    # Tuple for exact-match membership checks.
    # NOTE(review): this profile declares no ``token`` field; kept for compatibility.
    _private_attributes = ("token",)
    host: typing.Optional[str] = None
    port: typing.Optional[int] = None
    http_port: typing.Optional[int] = None
    user: typing.Optional[str] = None

    def secrets(self) -> typing.Optional[dict]:
        """Return the set ``HDFS_*`` settings, or ``None`` when none are set."""
        res = {}
        if self.host:
            res["HDFS_HOST"] = self.host
        if self.port:
            res["HDFS_PORT"] = self.port
        # Gate on http_port itself (it was previously gated on ``port``, which
        # could export None or drop a configured HTTP port).
        if self.http_port:
            res["HDFS_HTTP_PORT"] = self.http_port
        if self.user:
            res["HDFS_USER"] = self.user
        return res or None

    def url(self, subpath):
        """Return ``webhdfs://<host>:<http_port><subpath>``."""
        return f"webhdfs://{self.host}:{self.http_port}{subpath}"
class DatastoreProfile2Json(pydantic.BaseModel):
    """Serialize datastore profiles to, and rebuild them from, JSON strings.

    Every value is stored as base64 of ``str(value)``. A profile's public and
    private attributes are serialized separately.
    """

    @staticmethod
    def _to_json(attributes):
        """Return compact JSON mapping each key to base64 of ``str(value)``."""
        # First, base64 encode the values
        encoded_dict = {
            k: base64.b64encode(str(v).encode()).decode() for k, v in attributes.items()
        }
        # Then, return the dictionary as JSON with no spaces. Compact separators
        # achieve that without touching keys that might contain a space.
        return json.dumps(encoded_dict, separators=(",", ":"))

    @staticmethod
    def _from_json(json_str):
        """Parse ``json_str`` and base64-decode each value back to a string."""
        return {
            k: base64.b64decode(str(v).encode()).decode()
            for k, v in json.loads(json_str).items()
        }

    @staticmethod
    def get_json_public(profile: DatastoreProfile) -> str:
        """Serialize the profile fields NOT listed in ``_private_attributes``."""
        return DatastoreProfile2Json._to_json(
            {
                k: v
                for k, v in profile.dict().items()
                if str(k) not in profile._private_attributes
            }
        )

    @staticmethod
    def get_json_private(profile: DatastoreProfile) -> str:
        """Serialize only the profile fields listed in ``_private_attributes``."""
        return DatastoreProfile2Json._to_json(
            {
                k: v
                for k, v in profile.dict().items()
                if str(k) in profile._private_attributes
            }
        )

    @staticmethod
    def create_from_json(public_json: str, private_json: str = "{}"):
        """Rebuild a profile from its public and (optional) private JSON parts.

        :param public_json: output of ``get_json_public``.
        :param private_json: output of ``get_json_private``; defaults to no
            private attributes.
        :returns: an instance of the DatastoreProfile subclass selected by the
            decoded ``type`` value.
        :raises MLRunInvalidArgumentError: if ``type`` is missing or unknown.
        """
        decoded_dict = merge(
            DatastoreProfile2Json._from_json(public_json),
            DatastoreProfile2Json._from_json(private_json),
        )

        def safe_literal_eval(value):
            # ``_to_json`` stringified every value. Turn literals such as "None",
            # "8020" or "{'k': 'v'}" back into Python objects, and keep anything
            # that is not a literal as the plain string.
            # NOTE(review): string fields that look like literals (e.g. "1e5")
            # come back as non-strings; confirm this is acceptable.
            try:
                return ast.literal_eval(value)
            except (ValueError, TypeError, SyntaxError, RecursionError):
                # TypeError: e.g. a dict display with an unhashable key;
                # RecursionError: pathologically nested input.
                return value

        decoded_dict = {k: safe_literal_eval(v) for k, v in decoded_dict.items()}
        datastore_type = decoded_dict.get("type")
        ds_profile_factory = {
            "v3io": DatastoreProfileV3io,
            "s3": DatastoreProfileS3,
            "redis": DatastoreProfileRedis,
            "basic": DatastoreProfileBasic,
            "kafka_target": DatastoreProfileKafkaTarget,
            "kafka_source": DatastoreProfileKafkaSource,
            "dbfs": DatastoreProfileDBFS,
            "gcs": DatastoreProfileGCS,
            "az": DatastoreProfileAzureBlob,
            "hdfs": DatastoreProfileHdfs,
        }
        if datastore_type in ds_profile_factory:
            return ds_profile_factory[datastore_type].parse_obj(decoded_dict)
        if datastore_type:
            reason = f"unexpected type '{datastore_type}'"
        else:
            reason = "missing type"
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Datastore profile failed to create from json due to {reason}"
        )
def datastore_profile_read(url, project_name="", secrets: typing.Optional[dict] = None):
    """Load the datastore profile referenced by a ``ds://<profile-name>`` URL.

    A profile registered locally via ``TemporaryClientDatastoreProfiles`` is
    returned first. Otherwise the public part is fetched from the run DB. The
    private part is read with ``get_secret_or_env``, under the key built by
    ``DatastoreProfile.generate_secret_key``.

    :param url: resource URL with the ``ds`` scheme.
    :param project_name: project owning the profile; defaults to
        ``mlrun.mlconf.default_project``.
    :param secrets: optional secret provider forwarded to ``get_secret_or_env``.
    :returns: the reconstructed DatastoreProfile.
    :raises MLRunInvalidArgumentError: if the scheme is not ``ds`` or either
        part of the profile cannot be found.
    """
    parsed_url = urlparse(url)
    if parsed_url.scheme.lower() != "ds":
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"resource URL '{url}' cannot be read as a datastore profile because its scheme is not 'ds'"
        )

    profile_name = parsed_url.hostname
    project_name = project_name or mlrun.mlconf.default_project
    datastore = TemporaryClientDatastoreProfiles().get(profile_name)
    if datastore:
        return datastore

    public_profile = mlrun.db.get_run_db().get_datastore_profile(
        profile_name, project_name
    )
    # The mlrun.db.get_run_db().get_datastore_profile() function is capable of returning
    # two distinct types of objects based on its execution context.
    # If it operates from the client or within the pod (which is the common scenario),
    # it yields an instance of `mlrun.datastore.DatastoreProfile`. Conversely,
    # when executed on the server with a direct call to `sqldb`, it produces an instance of
    # mlrun.common.schemas.DatastoreProfile.
    # In the latter scenario, an extra conversion step is required to transform the object
    # into mlrun.datastore.DatastoreProfile.
    if isinstance(public_profile, mlrun.common.schemas.DatastoreProfile):
        # NOTE(review): assumes the schema object carries the public JSON in
        # ``.object`` - confirm against mlrun.common.schemas.DatastoreProfile.
        public_profile = DatastoreProfile2Json.create_from_json(
            public_json=public_profile.object
        )

    project_ds_name_private = DatastoreProfile.generate_secret_key(
        profile_name, project_name
    )
    private_body = get_secret_or_env(project_ds_name_private, secret_provider=secrets)
    if not public_profile or not private_body:
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Unable to retrieve the datastore profile '{url}' from either the server or local environment. "
            "Make sure the profile is registered correctly, or if running in a local environment, "
            "use register_temporary_client_datastore_profile() to provide credentials locally."
        )

    datastore = DatastoreProfile2Json.create_from_json(
        public_json=DatastoreProfile2Json.get_json_public(public_profile),
        private_json=private_body,
    )
    return datastore
def register_temporary_client_datastore_profile(profile: DatastoreProfile):
    """Register the datastore profile.

    This profile is temporary and remains valid only for the duration of the caller's session.
    It's beneficial for testing purposes.
    """
    TemporaryClientDatastoreProfiles().add(profile)
[docs]def remove_temporary_client_datastore_profile(profile_name: str):