# Source code for mlrun.secrets

# Copyright 2023 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from ast import literal_eval
from collections.abc import Callable
from os import environ
from typing import Union

import mlrun.auth.utils
import mlrun.utils.helpers
from mlrun.config import is_running_as_api

from .utils import AzureVaultStore, list2dict


class SecretsStore:
    """Collection of secrets gathered from one or more sources.

    Supported source kinds (see :meth:`add_source`):

    * ``inline``, ``file`` and ``env`` - values are kept in ``_secrets`` and are serialized
      by :meth:`to_serial` as a single inline source.
    * ``azure_vault`` and ``kubernetes`` - "hidden" sources whose values are kept in
      ``_hidden_secrets``. Only the source description (kept in ``_hidden_sources``) is
      serialized, so the values are fetched again when the source is re-added.
    """

    def __init__(self):
        # Plain secrets (inline / file / env sources); serialized verbatim by to_serial().
        self._secrets = {}
        # Hidden secrets' value must not be serialized. Only the keys can be. These secrets are retrieved externally,
        # for example from Vault, and when adding their source they will be retrieved from the external source.
        self._hidden_sources = []
        self._hidden_secrets = {}

    @classmethod
    def from_list(cls, src_list: list):
        """Create a store and add every source descriptor in ``src_list`` to it.

        :param src_list: List of dicts, each with a ``kind`` key and optional ``source`` and
                         ``prefix`` keys (the format produced by :meth:`to_serial`). Anything
                         that is not a non-empty list produces an empty store.
        :return: The new store.
        """
        store = cls()
        if src_list and isinstance(src_list, list):
            for src in src_list:
                store.add_source(src["kind"], src.get("source"), src.get("prefix", ""))
        return store

    def to_dict(self, struct):
        # NOTE(review): not implemented - this method does nothing and returns None.
        pass

    @staticmethod
    def _literal_source(source):
        """Return ``source`` evaluated as a Python literal when it is a string, else unchanged."""
        if isinstance(source, str):
            return literal_eval(source)
        return source

    def add_source(self, kind, source="", prefix=""):
        """Load secrets from a source and add them to the store.

        :param kind:   Source kind: ``inline``, ``file``, ``env``, ``azure_vault`` or
                       ``kubernetes``. Unrecognized kinds are silently ignored.
        :param source: Kind-specific source:

                       * ``inline`` - dict (or its string literal) of key -> value.
                       * ``file`` - path of a file whose lines are parsed with ``list2dict``.
                       * ``env`` - comma-separated environment variable names.
                       * ``azure_vault`` - dict (or its string literal) with ``name`` and
                         ``secrets`` keys.
                       * ``kubernetes`` - list (or its string literal) of secret names.
        :param prefix: String prepended to every key added from this source.
        :raises ValueError:   If an inline / azure_vault / kubernetes source has the wrong type,
                              or an azure_vault source has no ``name``.
        :raises RuntimeError: If a ``file`` source is added while running as the API.
        """
        if kind == "inline":
            source = self._literal_source(source)
            if not isinstance(source, dict):
                raise ValueError("inline secrets must be of type dict")
            for k, v in source.items():
                # Keep None as None instead of stringifying it to "None": an unset env var
                # captured by an "env" source is stored as None and emitted by to_serial()
                # inside an inline source, so str(None) here would turn "missing" into the
                # truthy value "None" after a to_serial() / from_list() round trip.
                self._secrets[prefix + k] = None if v is None else str(v)

        elif kind == "file":
            # Ensure files cannot be open from inside the API
            if is_running_as_api():
                raise RuntimeError(
                    "add_source of kind 'file' is not allowed from the API"
                )
            with open(source) as fp:
                lines = fp.read().splitlines()
                secrets_dict = list2dict(lines)
                for k, v in secrets_dict.items():
                    self._secrets[prefix + k] = str(v)

        elif kind == "env":
            for key in source.split(","):
                k = key.strip()
                # Unset variables are stored as None.
                self._secrets[prefix + k] = environ.get(k)
        # TODO: Vault: uncomment when vault returns to be relevant
        # elif kind == "vault":
        #     if isinstance(source, str):
        #         source = literal_eval(source)
        #     if not isinstance(source, dict):
        #         raise ValueError("vault secrets must be of type dict")
        #
        #     for key, value in self.vault.get_secrets(
        #         source["secrets"],
        #         user=source.get("user"),
        #         project=source.get("project"),
        #     ).items():
        #         self._hidden_secrets[prefix + key] = value
        #     self._hidden_sources.append({"kind": kind, "source": source})
        elif kind == "azure_vault":
            source = self._literal_source(source)
            if not isinstance(source, dict):
                raise ValueError("Azure vault secrets must be of type dict")
            if "name" not in source:
                raise ValueError(
                    "'name' must be provided in the source to define an Azure vault"
                )

            azure_vault = AzureVaultStore(source["name"])
            for key, value in azure_vault.get_secrets(source["secrets"]).items():
                self._hidden_secrets[prefix + key] = value
            self._hidden_sources.append({"kind": kind, "source": source})
        elif kind == "kubernetes":
            source = self._literal_source(source)
            if not isinstance(source, list):
                raise ValueError("k8s secrets must be of type list")
            for secret in source:
                # The values are expected to be mounted as prefixed env variables.
                env_value = environ.get(self.k8s_env_variable_name_for_secret(secret))
                if env_value:
                    self._hidden_secrets[prefix + secret] = env_value
            self._hidden_sources.append({"kind": kind, "source": source})

    def get(self, key, default=None):
        """Return the secret ``key``, or ``default`` when no source has a non-empty value.

        Lookup order: plain secrets, hidden secrets, then the MLRun kubernetes env variable
        for ``key`` (see :meth:`k8s_env_variable_name_for_secret`). Because the chain uses
        ``or``, empty values fall through to the next source.
        """
        return (
            self._secrets.get(key)
            or self._hidden_secrets.get(key)
            or environ.get(self.k8s_env_variable_name_for_secret(key))
            or default
        )

    def items(self):
        """Return ``(key, value)`` pairs of all secrets; hidden secrets override plain ones."""
        res = self._secrets.copy()
        if self._hidden_secrets:
            res.update(self._hidden_secrets)
        return res.items()

    def to_serial(self):
        """Return source descriptors from which :meth:`from_list` can rebuild this store.

        Plain secrets are included by value as one ``inline`` source. Hidden sources are
        included only by their descriptions, never by their values.
        """
        # todo: use encryption
        res = [{"kind": "inline", "source": self._secrets.copy()}]
        if self._hidden_sources:
            for src in self._hidden_sources.copy():
                res.append(src)
        return res

    def has_vault_source(self):
        """Return True if any hidden source is of kind ``vault``.

        add_source currently has no ``vault`` branch (it is commented out).
        """
        return any(source["kind"] == "vault" for source in self._hidden_sources)

    def has_azure_vault_source(self):
        """Return True if any hidden source is of kind ``azure_vault``."""
        return any(source["kind"] == "azure_vault" for source in self._hidden_sources)

    def get_azure_vault_k8s_secret(self):
        """Return ``k8s_secret`` of the first azure_vault source, or None if there is none."""
        for source in self._hidden_sources:
            if source["kind"] == "azure_vault":
                return source["source"].get("k8s_secret", None)

    @staticmethod
    def k8s_env_variable_name_for_secret(secret_name):
        """Return the env variable name MLRun uses for a kubernetes-mounted secret."""
        from mlrun.config import config

        return config.secret_stores.kubernetes.env_variable_prefix + secret_name

    def get_k8s_secrets(self):
        """Map each secret of the first kubernetes source to its env variable name.

        :return: ``{secret_name: env_variable_name}``, or None if there is no kubernetes source.
        """
        for source in self._hidden_sources:
            if source["kind"] == "kubernetes":
                return {
                    secret: self.k8s_env_variable_name_for_secret(secret)
                    for secret in source["source"]
                }
        return None


def get_secret_or_env(
    key: str,
    secret_provider: Union[dict, SecretsStore, Callable, None] = None,
    default: str | None = None,
    prefix: str | None = None,
) -> str | None:
    """Retrieve value of a secret, either from a user-provided secret store, or from environment variables.

    The sources below are tried in this exact order, and the first non-empty value is returned:

    1. ``secret_provider``, if provided (and truthy): a ``dict`` or MLRun ``SecretsStore`` is queried with
       ``.get(key)``; any other provider is called as ``secret_provider(key)``.
    2. An environment variable with the same key.
    3. Any environment variable that contains a JSON-encoded list of dicts with fields
       ``{'name': 'KEY', 'value': 'VAL', 'value_from': ...}``. The ``value`` of the entry whose ``name``
       matches the key is used.
    4. An MLRun-generated env. variable, mounted from a project secret (to be used in MLRun runtimes).
    5. The default value.

    Empty values (such as ``""``) returned by the provider or found in an environment variable are treated
    as "not found", and the lookup moves on to the next source.

    Example::

        secrets = {"KEY1": "VALUE1"}
        secret = get_secret_or_env("KEY1", secret_provider=secrets)


        # Using a function to retrieve a secret
        def my_secret_provider(key):
            # some internal logic to retrieve secret
            return value


        secret = get_secret_or_env(
            "KEY1", secret_provider=my_secret_provider, default="TOO-MANY-SECRETS"
        )

    :param key:             Secret key to look for
    :param secret_provider: Dictionary, callable or `SecretsStore` to extract the secret value from. If using a
                            callable, it must use the signature `callable(key:str)`
    :param default:         Default value to return if secret was not available through any other means
    :param prefix:          When passed, the prefix is added to the secret key, joined with an underscore
                            (``f"{prefix}_{key}"``). The prefixed key is used for every source.
    :return: The secret value if found in any of the sources, or `default` if provided.
    """
    if prefix:
        key = f"{prefix}_{key}"

    if secret_provider:
        if isinstance(secret_provider, dict | SecretsStore):
            secret_value = secret_provider.get(key)
        else:
            secret_value = secret_provider(key)
        if secret_value:
            return secret_value

    direct_environment_value = environ.get(key)
    if direct_environment_value:
        return direct_environment_value

    # The JSON-list fallback is scanned before the MLRun-generated variable (step 3 vs. 4 above).
    json_list_value = _find_value_in_json_env_lists(key)
    if json_list_value is not None:
        return json_list_value

    mlrun_env_key = SecretsStore.k8s_env_variable_name_for_secret(key)
    mlrun_env_value = environ.get(mlrun_env_key)
    if mlrun_env_value:
        return mlrun_env_value

    return default
def _find_value_in_json_env_lists( secret_name: str, ) -> str | None: """ Scan all environment variables. If any env var contains a JSON-encoded list of dicts shaped like {'name': str, 'value': str|None, 'value_from': ...}, return the 'value' for the entry whose 'name' matches secret_name. """ for environment_variable_value in environ.values(): if not environment_variable_value or not isinstance( environment_variable_value, str ): continue # Fast precheck to skip obvious non-JSON strings first_char = environment_variable_value.lstrip()[:1] if first_char not in ("[", "{"): continue try: parsed_value = json.loads(environment_variable_value) except ValueError: continue if isinstance(parsed_value, list): for entry in parsed_value: if isinstance(entry, dict) and entry.get("name") == secret_name: value_in_entry = entry.get("value") # Match original semantics: empty string is treated as "not found" if value_in_entry: return value_in_entry return None
@mlrun.utils.iguazio_v4_only
def sync_secret_tokens() -> None:
    """
    Synchronize local secret tokens with the backend.

    The sync is skipped (the function returns immediately) when:

    * the ``MLRUN_AUTH_OFFLINE_TOKEN`` environment variable is set, or
    * the code is running inside an MLRun runtime.

    Otherwise this function:

    1. Reads the local token file (defaults to `mlrun.mlconf.auth_with_oauth_token.token_file` value).
    2. Validates its content and resolves the token currently in use into a `SecretToken` object.
    3. Uploads the first prepared token to the backend via the run DB's ``store_secret_token``.

    NOTE(review): an earlier version of this docstring said a warning is logged when the backend
    token is updated due to a newer expiration time found locally. The upload below passes
    ``log_warning=False``, which (per the inline comment) silences warnings about local file
    updates. Confirm the backend-side warning behavior against ``store_secret_token``.

    :raises mlrun.errors.MLRunRuntimeError: If no valid token remains after loading and
        validating the token file.
    """
    # Do not sync tokens from the file when using the offline token environment variable.
    # The offline token from the env var takes precedence over the file.
    # Using the env var is not the recommended approach, and tokens from the env var
    # will not be saved as secrets in the backend.
    if os.getenv("MLRUN_AUTH_OFFLINE_TOKEN") or mlrun.utils.is_running_in_runtime():
        return

    # The import is needed here to prevent a circular import, since this method is called from the mlrun.db connection.
    from mlrun.db import get_run_db

    # Resolve the run DB once and use the same handle for the user lookup and the upload.
    run_db = get_run_db()
    secret_tokens = mlrun.auth.utils.load_and_prepare_secret_tokens(
        auth_user_id=run_db.token_provider.authenticated_user_id,
        raise_on_error=False,
    )
    if not secret_tokens:
        raise mlrun.errors.MLRunRuntimeError(
            "Authentication succeeded, but the token was not synced to the backend "
            "since no valid token was found after validation. "
            "Check your token file for malformed, expired, or mismatched tokens: "
            f"{mlrun.mlconf.auth_with_oauth_token.token_file}"
        )

    # The log_warning=False flag ensures the SDK doesn't log
    # unnecessary warnings about local file updates, since
    # this method reads from the file, not updates it.
    run_db.store_secret_token(secret_tokens[0], log_warning=False)