# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import base64
import json
import typing
from urllib.parse import ParseResult, quote, unquote, urlparse
import pydantic.v1
from deprecated import deprecated
from mergedeep import merge
import mlrun
import mlrun.errors
from ..secrets import get_secret_or_env
[docs]
class DatastoreProfile(pydantic.v1.BaseModel):
    """Base model shared by all datastore profiles.

    Every profile carries a ``type`` and a ``name``. Subclasses list the names
    of their credential-bearing fields in ``_private_attributes`` and override
    :meth:`secrets` and :meth:`url` where relevant.
    """

    type: str
    name: str
    # Names of the fields that hold private data; empty for the base class.
    _private_attributes: list = ()

    class Config:
        # Fields that the model does not declare are rejected.
        extra = pydantic.v1.Extra.forbid

    @pydantic.v1.validator("name")
    @classmethod
    def lower_case(cls, v):
        """Return the profile name converted to lower case."""
        return v.lower()

    @staticmethod
    def generate_secret_key(profile_name: str, project: str):
        """Build the dot-separated secret key for a profile.

        :param profile_name: Name of the datastore profile.
        :param project: Name of the project the profile belongs to.
        :return: ``datastore-profiles.<project>.<profile_name>``.
        """
        return ".".join(("datastore-profiles", project, profile_name))

    def secrets(self) -> dict:
        """Return the profile's secrets; the base implementation returns ``None``."""
        return None

    def url(self, subpath) -> str:
        """Return the target URL for ``subpath``; the base implementation returns ``None``."""
        return None
[docs]
class TemporaryClientDatastoreProfiles(metaclass=mlrun.utils.singleton.Singleton):
    """In-memory store of datastore profiles, keyed by profile name.

    The class is created through the ``Singleton`` metaclass from
    ``mlrun.utils.singleton``.
    """

    def __init__(self):
        # Maps profile name -> profile object.
        self._data = dict()

    def add(self, profile: DatastoreProfile):
        """Store ``profile`` under its name, replacing any previous entry."""
        self._data[profile.name] = profile

    def get(self, key):
        """Return the profile stored under ``key``, or ``None`` if there is none."""
        return self._data.get(key)

    def remove(self, key):
        """Drop the profile stored under ``key``; a missing key is ignored."""
        self._data.pop(key, None)
[docs]
class DatastoreProfileBasic(DatastoreProfile):
    """Minimal profile with one public and one optional private string."""

    # The ``private`` field is the one treated as private data.
    _private_attributes = "private"

    type: str = pydantic.v1.Field("basic")
    public: str
    private: str | None = None
[docs]
class ConfigProfile(DatastoreProfile):
    """
    Profile holding configuration data split into public and private parts.

    Both parts are optional dictionaries and may be nested. :meth:`attributes`
    merges them into a single dictionary, public first and private second.

    Args:
        public (Optional[dict]): Public configuration settings, may be nested.
        private (Optional[dict]): Private/sensitive configuration settings,
            may be nested.

    Example::

        public = {
            "database": {"host": "localhost", "port": 5432},
            "api_version": "v1",
        }
        private = {
            "database": {"password": "secret123", "username": "admin"},
            "api_key": "xyz789",
        }
        config = ConfigProfile(name="myconfig", public=public, private=private)

        # config.attributes() merges public and private:
        # {
        #     "database": {
        #         "host": "localhost",
        #         "port": 5432,
        #         "password": "secret123",
        #         "username": "admin",
        #     },
        #     "api_version": "v1",
        #     "api_key": "xyz789",
        # }
    """

    type = "config"
    _private_attributes = "private"
    public: dict | None = None
    private: dict | None = None

    def attributes(self):
        """Return ``public`` and ``private`` merged into one new dictionary.

        Parts that are ``None`` or empty are skipped; with both skipped the
        result is an empty dictionary.
        """
        merged = {}
        for part in (self.public, self.private):
            if part:
                merged = merge(merged, part)
        return merged
# TODO: Remove in 1.12.0
[docs]
@deprecated(
    version="1.10.0",
    reason=(
        "This class is deprecated from mlrun 1.10.0, and will be removed in 1.12.0. "
        "Use `DatastoreProfileKafkaStream` instead."
    ),
    category=FutureWarning,
)
class DatastoreProfileKafkaTarget(DatastoreProfile):
    """Deprecated Kafka target profile; use ``DatastoreProfileKafkaStream`` instead."""

    type: str = pydantic.v1.Field("kafka_target")
    _private_attributes = "kwargs_private"
    brokers: str
    topic: str
    kwargs_public: dict | None
    kwargs_private: dict | None

    def get_topic(self) -> str | None:
        """Return the configured topic."""
        return self.topic

    def attributes(self):
        """Return the brokers merged with the public and then the private kwargs."""
        result = {"brokers": self.brokers}
        for extra in (self.kwargs_public, self.kwargs_private):
            if extra:
                result = merge(result, extra)
        return result
[docs]
class DatastoreProfileKafkaStream(DatastoreProfile):
    """Kafka stream profile: brokers, topics, consumer settings and SASL credentials."""

    type: str = pydantic.v1.Field("kafka_stream")
    _private_attributes = ("kwargs_private", "sasl_user", "sasl_pass")
    brokers: typing.Union[str, list[str]]
    topics: typing.Union[str, list[str]]
    group: str | None = "serving"
    initial_offset: str | None = "earliest"
    partitions: typing.Union[str, list[str]] | None
    sasl_user: str | None
    sasl_pass: str | None
    kwargs_public: dict | None
    kwargs_private: dict | None

    def get_topic(self) -> str | None:
        """Return the single topic, or the first of several, or ``None`` if there are none."""
        if isinstance(self.topics, str):
            return self.topics
        return self.topics[0] if self.topics else None

    def attributes(self) -> dict[str, typing.Any]:
        """Assemble the attribute dictionary for this stream.

        Public and then private kwargs are merged first. ``brokers`` and
        ``topics`` are then stored as lists (a single string is wrapped),
        followed by ``group`` and ``initial_offset``. ``partitions`` is added
        only when it is not ``None``, and ``sasl`` only when the
        ``KafkaParameters.sasl`` helper returns a truthy value.
        """
        result = {}
        for extra in (self.kwargs_public, self.kwargs_private):
            if extra:
                result = merge(result, extra)

        result.update(
            brokers=[self.brokers] if isinstance(self.brokers, str) else self.brokers,
            topics=[self.topics] if isinstance(self.topics, str) else self.topics,
            group=self.group,
            initial_offset=self.initial_offset,
        )
        if self.partitions is not None:
            result["partitions"] = self.partitions

        sasl = mlrun.datastore.utils.KafkaParameters(result).sasl(
            usr=self.sasl_user, pwd=self.sasl_pass
        )
        if sasl:
            result["sasl"] = sasl
        return result
# TODO: Remove in 1.12.0
[docs]
@deprecated(
    version="1.10.0",
    reason=(
        "This class is deprecated from mlrun 1.10.0, and will be removed in 1.12.0. "
        "Use `DatastoreProfileKafkaStream` instead."
    ),
    category=FutureWarning,
)
class DatastoreProfileKafkaSource(DatastoreProfileKafkaStream):
    """Deprecated variant of ``DatastoreProfileKafkaStream`` with type ``kafka_source``."""

    type: str = pydantic.v1.Field("kafka_source")
[docs]
class DatastoreProfileRabbitMQ(DatastoreProfile):
    """
    Datastore profile for RabbitMQ connections.

    Used to configure RabbitMQ triggers for Nuclio functions.

    Example::

        profile = DatastoreProfileRabbitMQ(
            name="my-rabbitmq",
            broker_url="amqp://rabbitmq-host:5672",
            exchange_name="my-exchange",
            queue_name="my-queue",
            username="user",
            password="secret",
        )
        project.register_datastore_profile(profile)

        # Then use in trigger:
        function.add_rabbitmq_trigger(url="ds://my-rabbitmq")
    """

    type: str = "rabbitmq"
    _private_attributes = ("password", "username")
    broker_url: str
    exchange_name: str
    queue_name: str | None = None
    topics: typing.Union[str, list[str]] | None = None
    username: str | None = None
    password: str | None = None
    prefetch_count: int = 0
    durable_exchange: bool = False
    durable_queue: bool = False
    on_error: str = "nack"
    requeue_on_error: bool = False
    reconnect_duration: str = "5m"
    reconnect_interval: str = "15s"
    num_workers: int = 1
    worker_termination_timeout: str = "10s"

    def attributes(self) -> dict[str, typing.Any]:
        """Return the trigger attributes dictionary.

        ``broker_url`` is exposed under the ``url`` key and a single string
        in ``topics`` is wrapped in a list; every other field is passed
        through under its own name.
        """
        result = {
            "url": self.broker_url,
            "exchange_name": self.exchange_name,
            "queue_name": self.queue_name,
            "topics": [self.topics] if isinstance(self.topics, str) else self.topics,
        }
        passthrough_fields = (
            "username",
            "password",
            "prefetch_count",
            "durable_exchange",
            "durable_queue",
            "on_error",
            "requeue_on_error",
            "reconnect_duration",
            "reconnect_interval",
            "num_workers",
            "worker_termination_timeout",
        )
        result.update({field: getattr(self, field) for field in passthrough_fields})
        return result
[docs]
class DatastoreProfileV3io(DatastoreProfile):
    """Profile for v3io storage with an optional access key."""

    type: str = pydantic.v1.Field("v3io")
    v3io_access_key: str | None = None
    _private_attributes = "v3io_access_key"

    def url(self, subpath):
        """Return ``v3io:///<subpath>`` with all leading slashes of ``subpath`` removed."""
        return f"v3io:///{subpath.lstrip('/')}"

    def secrets(self) -> dict:
        """Return ``V3IO_ACCESS_KEY`` when an access key is set, otherwise an empty dict."""
        if not self.v3io_access_key:
            return {}
        return {"V3IO_ACCESS_KEY": self.v3io_access_key}
[docs]
class DatastoreProfileS3(DatastoreProfile):
    """Profile for an S3 bucket, with optional endpoint, AWS profile, role and keys."""

    type: str = pydantic.v1.Field("s3")
    _private_attributes = ("access_key_id", "secret_key")
    endpoint_url: str | None = None
    force_non_anonymous: str | None = None
    profile_name: str | None = None
    assume_role_arn: str | None = None
    access_key_id: str | None = None
    secret_key: str | None = None
    bucket: str

    def secrets(self) -> dict:
        """Return the environment-style settings for every field that has a truthy value."""
        candidates = (
            ("AWS_ACCESS_KEY_ID", self.access_key_id),
            ("AWS_SECRET_ACCESS_KEY", self.secret_key),
            ("AWS_ENDPOINT_URL_S3", self.endpoint_url),
            ("S3_NON_ANONYMOUS", self.force_non_anonymous),
            ("AWS_PROFILE", self.profile_name),
            ("MLRUN_AWS_ROLE_ARN", self.assume_role_arn),
        )
        return {key: value for key, value in candidates if value}

    def url(self, subpath):
        """Return the ``s3://`` URL for ``subpath``.

        ``subpath`` is appended as-is, so it is expected to start with ``/``.
        """
        # TODO: There is an inconsistency with DatastoreProfileGCS. In DatastoreProfileGCS,
        # we assume that the subpath can begin without a '/' character,
        # while here we assume it always starts with one.
        if not self.bucket:
            return f"s3:/{subpath}"
        return f"s3://{self.bucket}{subpath}"
[docs]
class DatastoreProfileRedis(DatastoreProfile):
    """Profile for a Redis endpoint with optional username and password."""

    type: str = pydantic.v1.Field("redis")
    _private_attributes = ("username", "password")
    endpoint_url: str
    username: str | None = None
    password: str | None = None

    def url_with_credentials(self):
        """Return ``endpoint_url`` with the credentials embedded in its network location.

        Username and password are percent-encoded so that characters such as
        ``@``, ``:`` and ``/`` do not break the URL. The password is included
        only when a username is set as well.
        """
        parsed = urlparse(self.endpoint_url)
        host = parsed.hostname
        user = quote(self.username, safe="") if self.username else None
        secret = quote(self.password, safe="") if self.password else None

        if user and secret:
            netloc = f"{user}:{secret}@{host}"
        elif user:
            netloc = f"{user}@{host}"
        else:
            netloc = host
        if parsed.port:
            netloc += f":{parsed.port}"

        # Only the network location changes; every other URL component is kept.
        return parsed._replace(netloc=netloc).geturl()

    def secrets(self) -> dict:
        """Return ``REDIS_USER`` / ``REDIS_PASSWORD`` for the credentials that are set."""
        candidates = (
            ("REDIS_USER", self.username),
            ("REDIS_PASSWORD", self.password),
        )
        return {key: value for key, value in candidates if value}

    def url(self, subpath):
        """Return ``endpoint_url`` with ``subpath`` appended as-is."""
        return self.endpoint_url + subpath
[docs]
class DatastoreProfileDBFS(DatastoreProfile):
    """Profile for DBFS with an optional host and token."""

    type: str = pydantic.v1.Field("dbfs")
    _private_attributes = ("token",)
    endpoint_url: str | None = None  # host
    token: str | None = None

    def url(self, subpath) -> str:
        """Return ``dbfs://`` followed by ``subpath`` as-is."""
        return f"dbfs://{subpath}"

    def secrets(self) -> dict:
        """Return ``DATABRICKS_TOKEN`` / ``DATABRICKS_HOST`` for the fields that are set."""
        candidates = (
            ("DATABRICKS_TOKEN", self.token),
            ("DATABRICKS_HOST", self.endpoint_url),
        )
        return {key: value for key, value in candidates if value}
[docs]
class DatastoreProfileGCS(DatastoreProfile):
    """Profile for a GCS bucket with credentials given as a file path or as content."""

    type: str = pydantic.v1.Field("gcs")
    _private_attributes = ("gcp_credentials",)
    credentials_path: str | None = None  # path to file.
    gcp_credentials: typing.Union[str, dict] | None = None
    bucket: str

    @pydantic.v1.validator("gcp_credentials", pre=True, always=True)
    @classmethod
    def convert_dict_to_json(cls, v):
        """Serialize dictionary credentials to a JSON string; pass other values through."""
        return json.dumps(v) if isinstance(v, dict) else v

    def url(self, subpath) -> str:
        """Return the ``gcs://`` URL for ``subpath``.

        One leading ``/`` is removed from ``subpath``, because in GCS the part
        after the scheme starts with the bucket name.
        """
        # TODO: the subpath is assumed not to start with a slash here,
        # while DatastoreProfileS3 makes the opposite assumption.
        relative = subpath.removeprefix("/")
        if not self.bucket:
            return f"gcs://{relative}"
        if not relative:
            return f"gcs://{self.bucket}"
        return f"gcs://{self.bucket}/{relative}"

    def secrets(self) -> dict:
        """Return the credential settings for the fields that are set."""
        candidates = (
            ("GOOGLE_APPLICATION_CREDENTIALS", self.credentials_path),
            ("GCP_CREDENTIALS", self.gcp_credentials),
        )
        return {key: value for key, value in candidates if value}
[docs]
class DatastoreProfileAzureBlob(DatastoreProfile):
    """Profile for an Azure Blob container with several alternative credential fields."""

    type: str = pydantic.v1.Field("az")
    _private_attributes = (
        "connection_string",
        "account_key",
        "client_secret",
        "sas_token",
        "credential",
    )
    connection_string: str | None = None
    account_name: str | None = None
    account_key: str | None = None
    tenant_id: str | None = None
    client_id: str | None = None
    client_secret: str | None = None
    sas_token: str | None = None
    credential: str | None = None
    container: str

    def url(self, subpath) -> str:
        """Return the ``az://`` URL for ``subpath``.

        One leading ``/`` is removed from ``subpath``, because in Azure the
        part after the scheme starts with the container name.
        """
        relative = subpath.removeprefix("/")
        if not self.container:
            return f"az://{relative}"
        if not relative:
            return f"az://{self.container}"
        return f"az://{self.container}/{relative}"

    def secrets(self) -> dict:
        """Return every credential field that has a truthy value, keyed by its field name."""
        secret_fields = (
            "connection_string",
            "account_name",
            "account_key",
            "tenant_id",
            "client_id",
            "client_secret",
            "sas_token",
            "credential",
        )
        collected = {}
        for field in secret_fields:
            value = getattr(self, field)
            if value:
                collected[field] = value
        return collected
[docs]
class DatastoreProfileHdfs(DatastoreProfile):
    """Profile for HDFS: host, RPC port, HTTP port and user."""

    type: str = pydantic.v1.Field("hdfs")
    # NOTE(review): no ``token`` field is declared on this class - confirm
    # whether this private attribute name is still intended.
    _private_attributes = "token"
    host: str | None = None
    port: int | None = None
    http_port: int | None = None
    user: str | None = None

    def secrets(self) -> dict:
        """Return the HDFS settings for every field that is set.

        :return: A dictionary with the ``HDFS_HOST``, ``HDFS_PORT``,
            ``HDFS_HTTP_PORT`` and ``HDFS_USER`` entries whose fields are
            truthy, or ``None`` when none of them is set.
        """
        res = {}
        if self.host:
            res["HDFS_HOST"] = self.host
        if self.port:
            res["HDFS_PORT"] = self.port
        # Guard on ``http_port`` itself: checking ``port`` here dropped the
        # HTTP port when only it was set, and emitted ``None`` when only
        # ``port`` was set.
        if self.http_port:
            res["HDFS_HTTP_PORT"] = self.http_port
        if self.user:
            res["HDFS_USER"] = self.user
        return res or None

    def url(self, subpath):
        """Return the ``webhdfs://<host>:<http_port>`` URL with ``subpath`` appended as-is."""
        return f"webhdfs://{self.host}:{self.http_port}{subpath}"
[docs]
class DatastoreProfilePostgreSQL(DatastoreProfile):
    """
    A profile that holds the required parameters for a PostgreSQL database.

    PostgreSQL uses standard PostgreSQL connection parameters.
    """

    type: str = pydantic.v1.Field("postgresql")
    _private_attributes = ["password"]
    user: str
    # The password cannot be empty in real world scenarios. It's here just because of the profiles completion design.
    password: str | None
    host: str
    port: int
    database: str = "postgres"  # Default PostgreSQL admin database

    def dsn(self, database: str | None = None) -> str:
        """
        Get the Data Source Name of the configured PostgreSQL profile.

        :param database: Optional database name to use instead of the configured one.
                         If None (or empty), the configured database is used.
        :return: The DSN string.
        """
        # Percent-encode user, password and database so special characters are safe.
        encoded_user = quote(self.user, safe="")
        encoded_password = quote(self.password or "", safe="")
        encoded_database = quote(database or self.database, safe="")
        location = f"{self.host}:{self.port}"
        return f"{self.type}://{encoded_user}:{encoded_password}@{location}/{encoded_database}"

    def admin_dsn(self) -> str:
        """
        Get DSN for administrative operations using the 'postgres' database.

        Assumes the default 'postgres' database exists (standard PostgreSQL setup).
        Used for admin tasks like creating/dropping databases.

        :return: DSN pointing to the 'postgres' database.
        """
        return self.dsn(database="postgres")

    @classmethod
    def from_dsn(cls, dsn: str, profile_name: str) -> "DatastoreProfilePostgreSQL":
        """
        Construct a PostgreSQL profile from DSN (connection string) and a name for the profile.

        :param dsn: The DSN (Data Source Name) of the PostgreSQL database,
                    e.g.: ``"postgresql://user:password@localhost:5432/mydb"``.
        :param profile_name: The new profile's name.
        :return: The PostgreSQL profile.
        """
        parts = urlparse(dsn)
        # urlparse leaves percent-encoding in place, so decode each piece explicitly.
        user = unquote(parts.username) if parts.username else None
        password = unquote(parts.password) if parts.password else None
        # A missing or empty path falls back to the default 'postgres' database.
        database = unquote(parts.path.lstrip("/")) or "postgres"
        return cls(
            name=profile_name,
            user=user,
            password=password,
            host=parts.hostname,
            port=parts.port,
            database=database,
        )
[docs]
class OpenAIProfile(DatastoreProfile):
    """Profile holding OpenAI client settings and the API key."""

    type: str = pydantic.v1.Field("openai")
    _private_attributes = "api_key"
    api_key: str | None = None
    organization: str | None = None
    project: str | None = None
    base_url: str | None = None
    timeout: float | None = None
    max_retries: int | None = None
    batch_max_concurrent: int | None = None

    def secrets(self) -> dict:
        """Return the ``OPENAI_*`` settings for every field that has a truthy value."""
        candidates = (
            ("OPENAI_API_KEY", self.api_key),
            ("OPENAI_ORG_ID", self.organization),
            ("OPENAI_PROJECT_ID", self.project),
            ("OPENAI_BASE_URL", self.base_url),
            ("OPENAI_TIMEOUT", self.timeout),
            ("OPENAI_MAX_RETRIES", self.max_retries),
            # per batch
            ("OPENAI_BATCH_MAX_CONCURRENT", self.batch_max_concurrent),
        )
        return {key: value for key, value in candidates if value}

    def url(self, subpath):
        """Return ``<type>://<subpath>`` with all leading slashes of ``subpath`` removed."""
        relative = subpath.lstrip("/")
        return f"{self.type}://{relative}"
[docs]
class HuggingFaceProfile(DatastoreProfile):
    """Profile holding Hugging Face settings, the access token and model kwargs."""

    type: str = pydantic.v1.Field("huggingface")
    _private_attributes = ("token", "model_kwargs")
    task: str | None = None
    token: str | None = None
    endpoint: str | None = None
    device: typing.Union[int, str] | None = None
    device_map: typing.Union[str, dict[str, typing.Union[int, str]], None] = None
    trust_remote_code: bool = None
    max_workers: int | None = None
    model_kwargs: dict[str, typing.Any] | None = None

    def secrets(self) -> dict:
        """Return the ``HF_*`` settings for every field that is not ``None``.

        Falsy values other than ``None`` (for example ``False`` or ``0``) are kept.
        """
        candidates = (
            ("HF_TASK", self.task),
            ("HF_TOKEN", self.token),
            ("HF_ENDPOINT", self.endpoint),
            ("HF_DEVICE", self.device),
            ("HF_DEVICE_MAP", self.device_map),
            ("HF_TRUST_REMOTE_CODE", self.trust_remote_code),
            ("HF_MAX_WORKERS", self.max_workers),
            ("HF_MODEL_KWARGS", self.model_kwargs),
        )
        return {key: value for key, value in candidates if value is not None}

    def url(self, subpath):
        """Return ``<type>://<subpath>`` with all leading slashes of ``subpath`` removed."""
        relative = subpath.lstrip("/")
        return f"{self.type}://{relative}"
# Registry mapping a profile's ``type`` value to the class used to rebuild it
# in DatastoreProfile2Json.create_from_json. A type missing from this mapping
# is rejected there as unexpected.
_DATASTORE_TYPE_TO_PROFILE_CLASS: dict[str, type[DatastoreProfile]] = {
    "v3io": DatastoreProfileV3io,
    "s3": DatastoreProfileS3,
    "redis": DatastoreProfileRedis,
    "basic": DatastoreProfileBasic,
    "kafka_target": DatastoreProfileKafkaTarget,
    "kafka_source": DatastoreProfileKafkaSource,
    "kafka_stream": DatastoreProfileKafkaStream,
    "dbfs": DatastoreProfileDBFS,
    "gcs": DatastoreProfileGCS,
    "az": DatastoreProfileAzureBlob,
    "hdfs": DatastoreProfileHdfs,
    "postgresql": DatastoreProfilePostgreSQL,
    "config": ConfigProfile,
    "openai": OpenAIProfile,
    "huggingface": HuggingFaceProfile,
    "rabbitmq": DatastoreProfileRabbitMQ,
}
[docs]
class DatastoreProfile2Json(pydantic.v1.BaseModel):
    """Convert datastore profiles to and from their JSON representation.

    A profile is serialized as two JSON objects, one with its public fields
    and one with its private fields. Every value is stringified and
    base64-encoded.
    """

    @staticmethod
    def _to_json(attributes):
        """Return ``attributes`` as compact JSON with base64-encoded string values."""
        # First, base64 encode the string form of every value.
        encoded_dict = {
            k: base64.b64encode(str(v).encode()).decode() for k, v in attributes.items()
        }
        # Then, return the dictionary as a JSON string with no spaces.
        return json.dumps(encoded_dict).replace(" ", "")

    @staticmethod
    def _private_attribute_names(profile: DatastoreProfile) -> frozenset:
        """Return the private field names declared by ``profile`` as a set.

        Profiles declare ``_private_attributes`` either as a single string or
        as a collection of strings. Testing ``name in "some_string"`` is a
        substring check, so a bare string is wrapped here to make every
        membership test an exact name comparison.
        """
        declared = profile._private_attributes
        if isinstance(declared, str):
            return frozenset((declared,))
        return frozenset(declared or ())

    @staticmethod
    def get_json_public(profile: DatastoreProfile) -> str:
        """Return the JSON encoding of the fields of ``profile`` that are not private."""
        private_names = DatastoreProfile2Json._private_attribute_names(profile)
        return DatastoreProfile2Json._to_json(
            {k: v for k, v in profile.dict().items() if str(k) not in private_names}
        )

    @staticmethod
    def get_json_private(profile: DatastoreProfile) -> str:
        """Return the JSON encoding of the private fields of ``profile``."""
        private_names = DatastoreProfile2Json._private_attribute_names(profile)
        return DatastoreProfile2Json._to_json(
            {k: v for k, v in profile.dict().items() if str(k) in private_names}
        )

    @staticmethod
    def create_from_json(public_json: str, private_json: str = "{}"):
        """Rebuild a profile from its public and private JSON encodings.

        :param public_json: JSON object with the base64-encoded public fields.
        :param private_json: JSON object with the base64-encoded private
            fields; defaults to an empty object.
        :return: An instance of the profile class registered for the decoded
            ``type`` value.
        :raises MLRunInvalidArgumentError: If ``type`` is missing or is not a
            registered profile type.
        """

        def decode(encoded_json):
            return {
                k: base64.b64decode(str(v).encode()).decode()
                for k, v in json.loads(encoded_json).items()
            }

        def safe_literal_eval(value):
            # Values were stringified on encoding; recover Python literals
            # (numbers, dicts, lists, None, ...) and keep anything that does
            # not parse as a literal as the plain string.
            try:
                return ast.literal_eval(value)
            except (ValueError, SyntaxError):
                return value

        attributes_public = decode(public_json)
        attributes_private = decode(private_json)
        decoded_dict = merge(attributes_public, attributes_private)
        decoded_dict = {k: safe_literal_eval(v) for k, v in decoded_dict.items()}

        datastore_type = decoded_dict.get("type")
        ds_profile_factory = _DATASTORE_TYPE_TO_PROFILE_CLASS
        if datastore_type in ds_profile_factory:
            return ds_profile_factory[datastore_type].parse_obj(decoded_dict)

        if datastore_type:
            reason = f"unexpected type '{decoded_dict['type']}'"
        else:
            reason = "missing type"
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"Datastore profile failed to create from json due to {reason}"
        )
[docs]
def datastore_profile_read(url, project_name="", secrets: dict | None = None):
    """
    Read and retrieve a datastore profile from a given URL.

    The profile is looked up first in the temporary client storage and, if it
    is not there, fetched from the MLRun database and combined with its
    private part, which is read from the secrets or the environment.

    Args:
        url (str): A URL with 'ds' scheme pointing to the datastore profile
            (e.g., 'ds://profile-name').
        project_name (str, optional): The project name where the profile is stored.
            Defaults to MLRun's active project.
        secrets (dict, optional): Secrets passed as the provider when looking up
            the profile's private part.

    Returns:
        DatastoreProfile: The retrieved datastore profile object.

    Raises:
        MLRunInvalidArgumentError: If the URL scheme is not 'ds'.
        MLRunNotFoundError: If either the public profile or its private part
            cannot be retrieved.

    Note:
        When running from a client environment (outside MLRun pods), private profile information
        is not accessible. In this case, use register_temporary_client_datastore_profile() to
        register the profile with credentials for your local session. When running inside MLRun
        pods, the private information is automatically available and no temporary registration is needed.
    """
    parsed = urlparse(url)
    if parsed.scheme.lower() != "ds":
        raise mlrun.errors.MLRunInvalidArgumentError(
            f"resource URL '{url}' cannot be read as a datastore profile because its scheme is not 'ds'"
        )

    profile_name = parsed.hostname
    project_name = project_name or mlrun.mlconf.active_project

    # A temporarily registered profile takes precedence over the database.
    temporary_profile = TemporaryClientDatastoreProfiles().get(profile_name)
    if temporary_profile:
        return temporary_profile

    public_profile = mlrun.db.get_run_db().get_datastore_profile(
        profile_name, project_name
    )
    # get_datastore_profile() returns one of two object types depending on where it runs.
    # From the client or within a pod (the common case) it yields an
    # `mlrun.datastore.DatastoreProfile`. On the server, with a direct call to `sqldb`,
    # it yields an `mlrun.common.schemas.DatastoreProfile`, which needs an extra
    # conversion step into `mlrun.datastore.DatastoreProfile`.
    if isinstance(public_profile, mlrun.common.schemas.DatastoreProfile):
        public_profile = DatastoreProfile2Json.create_from_json(
            public_json=public_profile.object
        )

    secret_key = DatastoreProfile.generate_secret_key(profile_name, project_name)
    private_body = get_secret_or_env(secret_key, secret_provider=secrets)
    if not (public_profile and private_body):
        raise mlrun.errors.MLRunNotFoundError(
            f"Unable to retrieve the datastore profile '{url}' from either the server or local environment. "
            "Make sure the profile is registered correctly, or if running in a local environment, "
            "use register_temporary_client_datastore_profile() to provide credentials locally."
        )

    return DatastoreProfile2Json.create_from_json(
        public_json=DatastoreProfile2Json.get_json_public(public_profile),
        private_json=private_body,
    )
[docs]
def register_temporary_client_datastore_profile(profile: DatastoreProfile):
    """Register the datastore profile in the temporary client storage.

    The profile is temporary and remains valid only for the duration of the
    caller's session, which makes it useful for testing.
    """
    temporary_profiles = TemporaryClientDatastoreProfiles()
    temporary_profiles.add(profile)
[docs]
def remove_temporary_client_datastore_profile(profile_name: str):
    """Remove the profile named ``profile_name`` from the temporary client storage.

    A name that is not registered is ignored.
    """
    temporary_profiles = TemporaryClientDatastoreProfiles()
    temporary_profiles.remove(profile_name)