"""Note: This is not used in the current implementation. It is a work in
progress and a placeholder for future work.
"""
import typing as _t
from action_triggers.base.config import ConnectionBase
from action_triggers.config_required_fields import RequiredFieldBase
[docs]
def validate_required_keys(
required_fields: _t.Sequence[RequiredFieldBase],
context: _t.Dict[str, _t.Any],
error_handler: _t.Callable[[str, str], None],
) -> None:
"""Validate that the required keys are present in the provided context.
:param required_fields: The required fields to check.
:param context: The context to check.
:param error_handler: The function to call if the required fields are not
"""
for field in required_fields:
if not field.check(context):
error_handler(field.key_repr, field.error_msg)
[docs]
def validate_context_not_overwritten(
context: _t.Union[_t.Dict[str, _t.Any], None],
user_context: _t.Dict[str, _t.Any],
error_handler: _t.Callable[[str, str], None],
) -> None:
"""Validate that the context is not overwritten.
:param context: The context to check.
:param user_context: The user context to check.
:param error_handler: The function to call if the required fields are not
"""
if context is None:
return
for key, value in context.items():
if key in user_context and user_context[key] != value:
error_handler(key, f"{key} cannot be overwritten.")
[docs]
class ConnectionValidationMixin:
"""The core validation class for the configuration. This class should
cover the basic validation requirements for any connection.
"""
required_conn_detail_fields: _t.Sequence[RequiredFieldBase]
required_params_fields: _t.Sequence[RequiredFieldBase]
config: _t.Dict[str, _t.Any]
conn_details: _t.Dict[str, _t.Any]
params: _t.Dict[str, _t.Any]
_user_conn_details: _t.Dict[str, _t.Any]
_user_params: _t.Dict[str, _t.Any]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.validate()
[docs]
def validate(self) -> None:
"""Validate the configuration."""
self.validate_required_conn_details()
self.validate_required_params()
self.validate_connection_details_not_overwritten()
self.validate_params_not_overwritten()
self._errors.is_valid(raise_exception=True) # type: ignore[attr-defined] # noqa: E501
[docs]
def validate_connection_details_not_overwritten(self) -> None:
"""Validate that the base connection details are not overwritten."""
validate_context_not_overwritten(
self.config.get("conn_details"),
self._user_conn_details,
self._errors.add_connection_params_error, # type: ignore[attr-defined] # noqa: E501
)
[docs]
def validate_params_not_overwritten(self) -> None:
"""Validate that the base parameters are not overwritten."""
validate_context_not_overwritten(
self.config.get("params"),
self._user_params,
self._errors.add_params_error, # type: ignore[attr-defined]
)
[docs]
def validate_required_conn_details(self) -> None:
"""Validate that the required connection details are present."""
validate_required_keys(
self.required_conn_detail_fields,
self.conn_details,
self._errors.add_connection_params_error, # type: ignore[attr-defined] # noqa: E501
)
[docs]
def validate_required_params(self) -> None:
"""Validate that the required parameters are present."""
validate_required_keys(
self.required_params_fields,
self.params,
self._errors.add_params_error, # type: ignore[attr-defined]
)
[docs]
class ConnectionCore(ConnectionValidationMixin, ConnectionBase):
"""The core connection class with some common validation that should be
applied to all connections.
"""