"""This module includes classes and functions designed specifically for use with the mypy plugin.""" from __future__ import annotations import sys from configparser import ConfigParser from typing import Any, Callable, Iterator from mypy.errorcodes import ErrorCode from mypy.expandtype import expand_type, expand_type_by_instance from mypy.nodes import ( ARG_NAMED, ARG_NAMED_OPT, ARG_OPT, ARG_POS, ARG_STAR2, INVARIANT, MDEF, Argument, AssignmentStmt, Block, CallExpr, ClassDef, Context, Decorator, DictExpr, EllipsisExpr, Expression, FuncDef, IfStmt, JsonDict, MemberExpr, NameExpr, PassStmt, PlaceholderNode, RefExpr, Statement, StrExpr, SymbolTableNode, TempNode, TypeAlias, TypeInfo, Var, ) from mypy.options import Options from mypy.plugin import ( CheckerPluginInterface, ClassDefContext, FunctionContext, MethodContext, Plugin, ReportConfigContext, SemanticAnalyzerPluginInterface, ) from mypy.plugins import dataclasses from mypy.plugins.common import ( deserialize_and_fixup_type, ) from mypy.semanal import set_callable_name from mypy.server.trigger import make_wildcard_trigger from mypy.state import state from mypy.typeops import map_type_from_supertype from mypy.types import ( AnyType, CallableType, Instance, NoneType, Overloaded, Type, TypeOfAny, TypeType, TypeVarType, UnionType, get_proper_type, ) from mypy.typevars import fill_typevars from mypy.util import get_unique_redefinition_name from mypy.version import __version__ as mypy_version from pydantic._internal import _fields from pydantic.version import parse_mypy_version try: from mypy.types import TypeVarDef # type: ignore[attr-defined] except ImportError: # pragma: no cover # Backward-compatible with TypeVarDef from Mypy 0.930. from mypy.types import TypeVarType as TypeVarDef CONFIGFILE_KEY = 'pydantic-mypy' METADATA_KEY = 'pydantic-mypy-metadata' BASEMODEL_FULLNAME = 'pydantic.main.BaseModel' BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings' ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel' MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass' FIELD_FULLNAME = 'pydantic.fields.Field' DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass' MODEL_VALIDATOR_FULLNAME = 'pydantic.functional_validators.model_validator' DECORATOR_FULLNAMES = { 'pydantic.functional_validators.field_validator', 'pydantic.functional_validators.model_validator', 'pydantic.functional_serializers.serializer', 'pydantic.functional_serializers.model_serializer', 'pydantic.deprecated.class_validators.validator', 'pydantic.deprecated.class_validators.root_validator', } MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version) BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__' # Increment version if plugin changes and mypy caches should be invalidated __version__ = 2 def plugin(version: str) -> type[Plugin]: """`version` is the mypy version string. We might want to use this to print a warning if the mypy version being used is newer, or especially older, than we expect (or need). Args: version: The mypy version string. Return: The Pydantic mypy plugin type. """ return PydanticPlugin class PydanticPlugin(Plugin): """The Pydantic mypy plugin.""" def __init__(self, options: Options) -> None: self.plugin_config = PydanticPluginConfig(options) self._plugin_data = self.plugin_config.to_data() super().__init__(options) def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], bool] | None: """Update Pydantic model class.""" sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): # pragma: no branch # No branching may occur if the mypy cache has not been cleared if any(base.fullname == BASEMODEL_FULLNAME for base in sym.node.mro): return self._pydantic_model_class_maker_callback return None def get_metaclass_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: """Update Pydantic `ModelMetaclass` definition.""" if fullname == MODEL_METACLASS_FULLNAME: return self._pydantic_model_metaclass_marker_callback return None def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None: """Adjust the return type of the `Field` function.""" sym = self.lookup_fully_qualified(fullname) if sym and sym.fullname == FIELD_FULLNAME: return self._pydantic_field_callback return None def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None: """Adjust return type of `from_orm` method call.""" if fullname.endswith('.from_orm'): return from_attributes_callback return None def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: """Mark pydantic.dataclasses as dataclass. Mypy version 1.1.1 added support for `@dataclass_transform` decorator. """ if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1): return dataclasses.dataclass_class_maker_callback # type: ignore[return-value] return None def report_config_data(self, ctx: ReportConfigContext) -> dict[str, Any]: """Return all plugin config data. Used by mypy to determine if cache needs to be discarded. """ return self._plugin_data def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool: transformer = PydanticModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config) return transformer.transform() def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None: """Reset dataclass_transform_spec attribute of ModelMetaclass. Let the plugin handle it. This behavior can be disabled if 'debug_dataclass_transform' is set to True', for testing purposes. """ if self.plugin_config.debug_dataclass_transform: return info_metaclass = ctx.cls.info.declared_metaclass assert info_metaclass, "callback not passed from 'get_metaclass_hook'" if getattr(info_metaclass.type, 'dataclass_transform_spec', None): info_metaclass.type.dataclass_transform_spec = None def _pydantic_field_callback(self, ctx: FunctionContext) -> Type: """Extract the type of the `default` argument from the Field function, and use it as the return type. In particular: * Check whether the default and default_factory argument is specified. * Output an error if both are specified. * Retrieve the type of the argument which is specified, and use it as return type for the function. """ default_any_type = ctx.default_return_type assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()' assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()' default_args = ctx.args[0] default_factory_args = ctx.args[1] if default_args and default_factory_args: error_default_and_default_factory_specified(ctx.api, ctx.context) return default_any_type if default_args: default_type = ctx.arg_types[0][0] default_arg = default_args[0] # Fallback to default Any type if the field is required if not isinstance(default_arg, EllipsisExpr): return default_type elif default_factory_args: default_factory_type = ctx.arg_types[1][0] # Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter # Pydantic calls the default factory without any argument, so we retrieve the first item if isinstance(default_factory_type, Overloaded): default_factory_type = default_factory_type.items[0] if isinstance(default_factory_type, CallableType): ret_type = default_factory_type.ret_type # mypy doesn't think `ret_type` has `args`, you'd think mypy should know, # add this check in case it varies by version args = getattr(ret_type, 'args', None) if args: if all(isinstance(arg, TypeVarType) for arg in args): # Looks like the default factory is a type like `list` or `dict`, replace all args with `Any` ret_type.args = tuple(default_any_type for _ in args) # type: ignore[attr-defined] return ret_type return default_any_type class PydanticPluginConfig: """A Pydantic mypy plugin config holder. Attributes: init_forbid_extra: Whether to add a `**kwargs` at the end of the generated `__init__` signature. init_typed: Whether to annotate fields in the generated `__init__`. warn_required_dynamic_aliases: Whether to raise required dynamic aliases error. debug_dataclass_transform: Whether to not reset `dataclass_transform_spec` attribute of `ModelMetaclass` for testing purposes. """ __slots__ = ( 'init_forbid_extra', 'init_typed', 'warn_required_dynamic_aliases', 'debug_dataclass_transform', ) init_forbid_extra: bool init_typed: bool warn_required_dynamic_aliases: bool debug_dataclass_transform: bool # undocumented def __init__(self, options: Options) -> None: if options.config_file is None: # pragma: no cover return toml_config = parse_toml(options.config_file) if toml_config is not None: config = toml_config.get('tool', {}).get('pydantic-mypy', {}) for key in self.__slots__: setting = config.get(key, False) if not isinstance(setting, bool): raise ValueError(f'Configuration value must be a boolean for key: {key}') setattr(self, key, setting) else: plugin_config = ConfigParser() plugin_config.read(options.config_file) for key in self.__slots__: setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False) setattr(self, key, setting) def to_data(self) -> dict[str, Any]: """Returns a dict of config names to their values.""" return {key: getattr(self, key) for key in self.__slots__} def from_attributes_callback(ctx: MethodContext) -> Type: """Raise an error if from_attributes is not enabled.""" model_type: Instance ctx_type = ctx.type if isinstance(ctx_type, TypeType): ctx_type = ctx_type.item if isinstance(ctx_type, CallableType) and isinstance(ctx_type.ret_type, Instance): model_type = ctx_type.ret_type # called on the class elif isinstance(ctx_type, Instance): model_type = ctx_type # called on an instance (unusual, but still valid) else: # pragma: no cover detail = f'ctx.type: {ctx_type} (of type {ctx_type.__class__.__name__})' error_unexpected_behavior(detail, ctx.api, ctx.context) return ctx.default_return_type pydantic_metadata = model_type.type.metadata.get(METADATA_KEY) if pydantic_metadata is None: return ctx.default_return_type if not any(base.fullname == BASEMODEL_FULLNAME for base in model_type.type.mro): # not a Pydantic v2 model return ctx.default_return_type from_attributes = pydantic_metadata.get('config', {}).get('from_attributes') if from_attributes is not True: error_from_attributes(model_type.type.name, ctx.api, ctx.context) return ctx.default_return_type class PydanticModelField: """Based on mypy.plugins.dataclasses.DataclassAttribute.""" def __init__( self, name: str, alias: str | None, is_frozen: bool, has_dynamic_alias: bool, has_default: bool, strict: bool | None, line: int, column: int, type: Type | None, info: TypeInfo, ): self.name = name self.alias = alias self.is_frozen = is_frozen self.has_dynamic_alias = has_dynamic_alias self.has_default = has_default self.strict = strict self.line = line self.column = column self.type = type self.info = info def to_argument( self, current_info: TypeInfo, typed: bool, model_strict: bool, force_optional: bool, use_alias: bool, api: SemanticAnalyzerPluginInterface, force_typevars_invariant: bool, is_root_model_root: bool, ) -> Argument: """Based on mypy.plugins.dataclasses.DataclassAttribute.to_argument.""" variable = self.to_var(current_info, api, use_alias, force_typevars_invariant) strict = model_strict if self.strict is None else self.strict if typed or strict: type_annotation = self.expand_type(current_info, api) else: type_annotation = AnyType(TypeOfAny.explicit) return Argument( variable=variable, type_annotation=type_annotation, initializer=None, kind=ARG_OPT if is_root_model_root else (ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED), ) def expand_type( self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface, force_typevars_invariant: bool = False ) -> Type | None: """Based on mypy.plugins.dataclasses.DataclassAttribute.expand_type.""" # The getattr in the next line is used to prevent errors in legacy versions of mypy without this attribute if force_typevars_invariant: # In some cases, mypy will emit an error "Cannot use a covariant type variable as a parameter" # To prevent that, we add an option to replace typevars with invariant ones while building certain # method signatures (in particular, `__init__`). There may be a better way to do this, if this causes # us problems in the future, we should look into why the dataclasses plugin doesn't have this issue. if isinstance(self.type, TypeVarType): modified_type = self.type.copy_modified() modified_type.variance = INVARIANT self.type = modified_type if self.type is not None and getattr(self.info, 'self_type', None) is not None: # In general, it is not safe to call `expand_type()` during semantic analyzis, # however this plugin is called very late, so all types should be fully ready. # Also, it is tricky to avoid eager expansion of Self types here (e.g. because # we serialize attributes). with state.strict_optional_set(api.options.strict_optional): filled_with_typevars = fill_typevars(current_info) if force_typevars_invariant: for arg in filled_with_typevars.args: if isinstance(arg, TypeVarType): arg.variance = INVARIANT return expand_type(self.type, {self.info.self_type.id: filled_with_typevars}) return self.type def to_var( self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface, use_alias: bool, force_typevars_invariant: bool = False, ) -> Var: """Based on mypy.plugins.dataclasses.DataclassAttribute.to_var.""" if use_alias and self.alias is not None: name = self.alias else: name = self.name return Var(name, self.expand_type(current_info, api, force_typevars_invariant)) def serialize(self) -> JsonDict: """Based on mypy.plugins.dataclasses.DataclassAttribute.serialize.""" assert self.type return { 'name': self.name, 'alias': self.alias, 'is_frozen': self.is_frozen, 'has_dynamic_alias': self.has_dynamic_alias, 'has_default': self.has_default, 'strict': self.strict, 'line': self.line, 'column': self.column, 'type': self.type.serialize(), } @classmethod def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface) -> PydanticModelField: """Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize.""" data = data.copy() typ = deserialize_and_fixup_type(data.pop('type'), api) return cls(type=typ, info=info, **data) def expand_typevar_from_subtype(self, sub_type: TypeInfo, api: SemanticAnalyzerPluginInterface) -> None: """Expands type vars in the context of a subtype when an attribute is inherited from a generic super type. """ if self.type is not None: with state.strict_optional_set(api.options.strict_optional): self.type = map_type_from_supertype(self.type, sub_type, self.info) class PydanticModelClassVar: """Based on mypy.plugins.dataclasses.DataclassAttribute. ClassVars are ignored by subclasses. Attributes: name: the ClassVar name """ def __init__(self, name): self.name = name @classmethod def deserialize(cls, data: JsonDict) -> PydanticModelClassVar: """Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize.""" data = data.copy() return cls(**data) def serialize(self) -> JsonDict: """Based on mypy.plugins.dataclasses.DataclassAttribute.serialize.""" return { 'name': self.name, } class PydanticModelTransformer: """Transform the BaseModel subclass according to the plugin settings. Attributes: tracked_config_fields: A set of field configs that the plugin has to track their value. """ tracked_config_fields: set[str] = { 'extra', 'frozen', 'from_attributes', 'populate_by_name', 'alias_generator', 'strict', } def __init__( self, cls: ClassDef, reason: Expression | Statement, api: SemanticAnalyzerPluginInterface, plugin_config: PydanticPluginConfig, ) -> None: self._cls = cls self._reason = reason self._api = api self.plugin_config = plugin_config def transform(self) -> bool: """Configures the BaseModel subclass according to the plugin settings. In particular: * determines the model config and fields, * adds a fields-aware signature for the initializer and construct methods * freezes the class if frozen = True * stores the fields, config, and if the class is settings in the mypy metadata for access by subclasses """ info = self._cls.info is_root_model = any(ROOT_MODEL_FULLNAME in base.fullname for base in info.mro[:-1]) config = self.collect_config() fields, class_vars = self.collect_fields_and_class_vars(config, is_root_model) if fields is None or class_vars is None: # Some definitions are not ready. We need another pass. return False for field in fields: if field.type is None: return False is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1]) self.add_initializer(fields, config, is_settings, is_root_model) self.add_model_construct_method(fields, config, is_settings, is_root_model) self.set_frozen(fields, self._api, frozen=config.frozen is True) self.adjust_decorator_signatures() info.metadata[METADATA_KEY] = { 'fields': {field.name: field.serialize() for field in fields}, 'class_vars': {class_var.name: class_var.serialize() for class_var in class_vars}, 'config': config.get_values_dict(), } return True def adjust_decorator_signatures(self) -> None: """When we decorate a function `f` with `pydantic.validator(...)`, `pydantic.field_validator` or `pydantic.serializer(...)`, mypy sees `f` as a regular method taking a `self` instance, even though pydantic internally wraps `f` with `classmethod` if necessary. Teach mypy this by marking any function whose outermost decorator is a `validator()`, `field_validator()` or `serializer()` call as a `classmethod`. """ for sym in self._cls.info.names.values(): if isinstance(sym.node, Decorator): first_dec = sym.node.original_decorators[0] if ( isinstance(first_dec, CallExpr) and isinstance(first_dec.callee, NameExpr) and first_dec.callee.fullname in DECORATOR_FULLNAMES # @model_validator(mode="after") is an exception, it expects a regular method and not ( first_dec.callee.fullname == MODEL_VALIDATOR_FULLNAME and any( first_dec.arg_names[i] == 'mode' and isinstance(arg, StrExpr) and arg.value == 'after' for i, arg in enumerate(first_dec.args) ) ) ): # TODO: Only do this if the first argument of the decorated function is `cls` sym.node.func.is_class = True def collect_config(self) -> ModelConfigData: # noqa: C901 (ignore complexity) """Collects the values of the config attributes that are used by the plugin, accounting for parent classes.""" cls = self._cls config = ModelConfigData() has_config_kwargs = False has_config_from_namespace = False # Handle `class MyModel(BaseModel, =, ...):` for name, expr in cls.keywords.items(): config_data = self.get_config_update(name, expr) if config_data: has_config_kwargs = True config.update(config_data) # Handle `model_config` stmt: Statement | None = None for stmt in cls.defs.body: if not isinstance(stmt, (AssignmentStmt, ClassDef)): continue if isinstance(stmt, AssignmentStmt): lhs = stmt.lvalues[0] if not isinstance(lhs, NameExpr) or lhs.name != 'model_config': continue if isinstance(stmt.rvalue, CallExpr): # calls to `dict` or `ConfigDict` for arg_name, arg in zip(stmt.rvalue.arg_names, stmt.rvalue.args): if arg_name is None: continue config.update(self.get_config_update(arg_name, arg, lax_extra=True)) elif isinstance(stmt.rvalue, DictExpr): # dict literals for key_expr, value_expr in stmt.rvalue.items: if not isinstance(key_expr, StrExpr): continue config.update(self.get_config_update(key_expr.value, value_expr)) elif isinstance(stmt, ClassDef): if stmt.name != 'Config': # 'deprecated' Config-class continue for substmt in stmt.defs.body: if not isinstance(substmt, AssignmentStmt): continue lhs = substmt.lvalues[0] if not isinstance(lhs, NameExpr): continue config.update(self.get_config_update(lhs.name, substmt.rvalue)) if has_config_kwargs: self._api.fail( 'Specifying config in two places is ambiguous, use either Config attribute or class kwargs', cls, ) break has_config_from_namespace = True if has_config_kwargs or has_config_from_namespace: if ( stmt and config.has_alias_generator and not config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases ): error_required_dynamic_aliases(self._api, stmt) for info in cls.info.mro[1:]: # 0 is the current class if METADATA_KEY not in info.metadata: continue # Each class depends on the set of fields in its ancestors self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) for name, value in info.metadata[METADATA_KEY]['config'].items(): config.setdefault(name, value) return config def collect_fields_and_class_vars( self, model_config: ModelConfigData, is_root_model: bool ) -> tuple[list[PydanticModelField] | None, list[PydanticModelClassVar] | None]: """Collects the fields for the model, accounting for parent classes.""" cls = self._cls # First, collect fields and ClassVars belonging to any class in the MRO, ignoring duplicates. # # We iterate through the MRO in reverse because attrs defined in the parent must appear # earlier in the attributes list than attrs defined in the child. See: # https://docs.python.org/3/library/dataclasses.html#inheritance # # However, we also want fields defined in the subtype to override ones defined # in the parent. We can implement this via a dict without disrupting the attr order # because dicts preserve insertion order in Python 3.7+. found_fields: dict[str, PydanticModelField] = {} found_class_vars: dict[str, PydanticModelClassVar] = {} for info in reversed(cls.info.mro[1:-1]): # 0 is the current class, -2 is BaseModel, -1 is object # if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata: # # We haven't processed the base class yet. Need another pass. # return None, None if METADATA_KEY not in info.metadata: continue # Each class depends on the set of attributes in its dataclass ancestors. self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname)) for name, data in info.metadata[METADATA_KEY]['fields'].items(): field = PydanticModelField.deserialize(info, data, self._api) # (The following comment comes directly from the dataclasses plugin) # TODO: We shouldn't be performing type operations during the main # semantic analysis pass, since some TypeInfo attributes might # still be in flux. This should be performed in a later phase. field.expand_typevar_from_subtype(cls.info, self._api) found_fields[name] = field sym_node = cls.info.names.get(name) if sym_node and sym_node.node and not isinstance(sym_node.node, Var): self._api.fail( 'BaseModel field may only be overridden by another field', sym_node.node, ) # Collect ClassVars for name, data in info.metadata[METADATA_KEY]['class_vars'].items(): found_class_vars[name] = PydanticModelClassVar.deserialize(data) # Second, collect fields and ClassVars belonging to the current class. current_field_names: set[str] = set() current_class_vars_names: set[str] = set() for stmt in self._get_assignment_statements_from_block(cls.defs): maybe_field = self.collect_field_or_class_var_from_stmt(stmt, model_config, found_class_vars) if isinstance(maybe_field, PydanticModelField): lhs = stmt.lvalues[0] if is_root_model and lhs.name != 'root': error_extra_fields_on_root_model(self._api, stmt) else: current_field_names.add(lhs.name) found_fields[lhs.name] = maybe_field elif isinstance(maybe_field, PydanticModelClassVar): lhs = stmt.lvalues[0] current_class_vars_names.add(lhs.name) found_class_vars[lhs.name] = maybe_field return list(found_fields.values()), list(found_class_vars.values()) def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]: for body in stmt.body: if not body.is_unreachable: yield from self._get_assignment_statements_from_block(body) if stmt.else_body is not None and not stmt.else_body.is_unreachable: yield from self._get_assignment_statements_from_block(stmt.else_body) def _get_assignment_statements_from_block(self, block: Block) -> Iterator[AssignmentStmt]: for stmt in block.body: if isinstance(stmt, AssignmentStmt): yield stmt elif isinstance(stmt, IfStmt): yield from self._get_assignment_statements_from_if_statement(stmt) def collect_field_or_class_var_from_stmt( # noqa C901 self, stmt: AssignmentStmt, model_config: ModelConfigData, class_vars: dict[str, PydanticModelClassVar] ) -> PydanticModelField | PydanticModelClassVar | None: """Get pydantic model field from statement. Args: stmt: The statement. model_config: Configuration settings for the model. class_vars: ClassVars already known to be defined on the model. Returns: A pydantic model field if it could find the field in statement. Otherwise, `None`. """ cls = self._cls lhs = stmt.lvalues[0] if not isinstance(lhs, NameExpr) or not _fields.is_valid_field_name(lhs.name) or lhs.name == 'model_config': return None if not stmt.new_syntax: if ( isinstance(stmt.rvalue, CallExpr) and isinstance(stmt.rvalue.callee, CallExpr) and isinstance(stmt.rvalue.callee.callee, NameExpr) and stmt.rvalue.callee.callee.fullname in DECORATOR_FULLNAMES ): # This is a (possibly-reused) validator or serializer, not a field # In particular, it looks something like: my_validator = validator('my_field')(f) # Eventually, we may want to attempt to respect model_config['ignored_types'] return None if lhs.name in class_vars: # Class vars are not fields and are not required to be annotated return None # The assignment does not have an annotation, and it's not anything else we recognize error_untyped_fields(self._api, stmt) return None lhs = stmt.lvalues[0] if not isinstance(lhs, NameExpr): return None if not _fields.is_valid_field_name(lhs.name) or lhs.name == 'model_config': return None sym = cls.info.names.get(lhs.name) if sym is None: # pragma: no cover # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation) # This is the same logic used in the dataclasses plugin return None node = sym.node if isinstance(node, PlaceholderNode): # pragma: no cover # See the PlaceholderNode docstring for more detail about how this can occur # Basically, it is an edge case when dealing with complex import logic # The dataclasses plugin now asserts this cannot happen, but I'd rather not error if it does.. return None if isinstance(node, TypeAlias): self._api.fail( 'Type aliases inside BaseModel definitions are not supported at runtime', node, ) # Skip processing this node. This doesn't match the runtime behaviour, # but the only alternative would be to modify the SymbolTable, # and it's a little hairy to do that in a plugin. return None if not isinstance(node, Var): # pragma: no cover # Don't know if this edge case still happens with the `is_valid_field` check above # but better safe than sorry # The dataclasses plugin now asserts this cannot happen, but I'd rather not error if it does.. return None # x: ClassVar[int] is not a field if node.is_classvar: return PydanticModelClassVar(lhs.name) # x: InitVar[int] is not supported in BaseModel node_type = get_proper_type(node.type) if isinstance(node_type, Instance) and node_type.type.fullname == 'dataclasses.InitVar': self._api.fail( 'InitVar is not supported in BaseModel', node, ) has_default = self.get_has_default(stmt) strict = self.get_strict(stmt) if sym.type is None and node.is_final and node.is_inferred: # This follows the logic from the dataclasses plugin. The following comment is taken verbatim: # # This is a special case, assignment like x: Final = 42 is classified # annotated above, but mypy strips the `Final` turning it into x = 42. # We do not support inferred types in dataclasses, so we can try inferring # type for simple literals, and otherwise require an explicit type # argument for Final[...]. typ = self._api.analyze_simple_literal_type(stmt.rvalue, is_final=True) if typ: node.type = typ else: self._api.fail( 'Need type argument for Final[...] with non-literal default in BaseModel', stmt, ) node.type = AnyType(TypeOfAny.from_error) alias, has_dynamic_alias = self.get_alias_info(stmt) if has_dynamic_alias and not model_config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases: error_required_dynamic_aliases(self._api, stmt) is_frozen = self.is_field_frozen(stmt) init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt) return PydanticModelField( name=lhs.name, has_dynamic_alias=has_dynamic_alias, has_default=has_default, strict=strict, alias=alias, is_frozen=is_frozen, line=stmt.line, column=stmt.column, type=init_type, info=cls.info, ) def _infer_dataclass_attr_init_type(self, sym: SymbolTableNode, name: str, context: Context) -> Type | None: """Infer __init__ argument type for an attribute. In particular, possibly use the signature of __set__. """ default = sym.type if sym.implicit: return default t = get_proper_type(sym.type) # Perform a simple-minded inference from the signature of __set__, if present. # We can't use mypy.checkmember here, since this plugin runs before type checking. # We only support some basic scanerios here, which is hopefully sufficient for # the vast majority of use cases. if not isinstance(t, Instance): return default setter = t.type.get('__set__') if setter: if isinstance(setter.node, FuncDef): super_info = t.type.get_containing_type_info('__set__') assert super_info if setter.type: setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info)) else: return AnyType(TypeOfAny.unannotated) if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [ ARG_POS, ARG_POS, ARG_POS, ]: return expand_type_by_instance(setter_type.arg_types[2], t) else: self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context) else: self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context) return default def add_initializer( self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool ) -> None: """Adds a fields-aware `__init__` method to the class. The added `__init__` will be annotated with types vs. all `Any` depending on the plugin settings. """ if '__init__' in self._cls.info.names and not self._cls.info.names['__init__'].plugin_generated: return # Don't generate an __init__ if one already exists typed = self.plugin_config.init_typed model_strict = bool(config.strict) use_alias = config.populate_by_name is not True requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name) args = self.get_field_arguments( fields, typed=typed, model_strict=model_strict, requires_dynamic_aliases=requires_dynamic_aliases, use_alias=use_alias, is_settings=is_settings, is_root_model=is_root_model, force_typevars_invariant=True, ) if is_root_model and MYPY_VERSION_TUPLE <= (1, 0, 1): # convert root argument to positional argument # This is needed because mypy support for `dataclass_transform` isn't complete on 1.0.1 args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT if is_settings: base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node if '__init__' in base_settings_node.names: base_settings_init_node = base_settings_node.names['__init__'].node if base_settings_init_node is not None and base_settings_init_node.type is not None: func_type = base_settings_init_node.type for arg_idx, arg_name in enumerate(func_type.arg_names): if arg_name.startswith('__') or not arg_name.startswith('_'): continue analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx]) variable = Var(arg_name, analyzed_variable_type) args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT)) if not self.should_init_forbid_extra(fields, config): var = Var('kwargs') args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) add_method(self._api, self._cls, '__init__', args=args, return_type=NoneType()) def add_model_construct_method( self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool, ) -> None: """Adds a fully typed `model_construct` classmethod to the class. Similar to the fields-aware __init__ method, but always uses the field names (not aliases), and does not treat settings fields as optional. """ set_str = self._api.named_type(f'{BUILTINS_NAME}.set', [self._api.named_type(f'{BUILTINS_NAME}.str')]) optional_set_str = UnionType([set_str, NoneType()]) fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT) with state.strict_optional_set(self._api.options.strict_optional): args = self.get_field_arguments( fields, typed=True, model_strict=bool(config.strict), requires_dynamic_aliases=False, use_alias=False, is_settings=is_settings, is_root_model=is_root_model, ) if not self.should_init_forbid_extra(fields, config): var = Var('kwargs') args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2)) args = args + [fields_set_argument] if is_root_model else [fields_set_argument] + args add_method( self._api, self._cls, 'model_construct', args=args, return_type=fill_typevars(self._cls.info), is_classmethod=True, ) def set_frozen(self, fields: list[PydanticModelField], api: SemanticAnalyzerPluginInterface, frozen: bool) -> None: """Marks all fields as properties so that attempts to set them trigger mypy errors. This is the same approach used by the attrs and dataclasses plugins. """ info = self._cls.info for field in fields: sym_node = info.names.get(field.name) if sym_node is not None: var = sym_node.node if isinstance(var, Var): var.is_property = frozen or field.is_frozen elif isinstance(var, PlaceholderNode) and not self._api.final_iteration: # See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage self._api.defer() else: # pragma: no cover # I don't know whether it's possible to hit this branch, but I've added it for safety try: var_str = str(var) except TypeError: # This happens for PlaceholderNode; perhaps it will happen for other types in the future.. var_str = repr(var) detail = f'sym_node.node: {var_str} (of type {var.__class__})' error_unexpected_behavior(detail, self._api, self._cls) else: var = field.to_var(info, api, use_alias=False) var.info = info var.is_property = frozen var._fullname = info.fullname + '.' + var.name info.names[var.name] = SymbolTableNode(MDEF, var) def get_config_update(self, name: str, arg: Expression, lax_extra: bool = False) -> ModelConfigData | None: """Determines the config update due to a single kwarg in the ConfigDict definition. Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int) """ if name not in self.tracked_config_fields: return None if name == 'extra': if isinstance(arg, StrExpr): forbid_extra = arg.value == 'forbid' elif isinstance(arg, MemberExpr): forbid_extra = arg.name == 'forbid' else: if not lax_extra: # Only emit an error for other types of `arg` (e.g., `NameExpr`, `ConditionalExpr`, etc.) when # reading from a config class, etc. If a ConfigDict is used, then we don't want to emit an error # because you'll get type checking from the ConfigDict itself. # # It would be nice if we could introspect the types better otherwise, but I don't know what the API # is to evaluate an expr into its type and then check if that type is compatible with the expected # type. Note that you can still get proper type checking via: `model_config = ConfigDict(...)`, just # if you don't use an explicit string, the plugin won't be able to infer whether extra is forbidden. error_invalid_config_value(name, self._api, arg) return None return ModelConfigData(forbid_extra=forbid_extra) if name == 'alias_generator': has_alias_generator = True if isinstance(arg, NameExpr) and arg.fullname == 'builtins.None': has_alias_generator = False return ModelConfigData(has_alias_generator=has_alias_generator) if isinstance(arg, NameExpr) and arg.fullname in ('builtins.True', 'builtins.False'): return ModelConfigData(**{name: arg.fullname == 'builtins.True'}) error_invalid_config_value(name, self._api, arg) return None @staticmethod def get_has_default(stmt: AssignmentStmt) -> bool: """Returns a boolean indicating whether the field defined in `stmt` is a required field.""" expr = stmt.rvalue if isinstance(expr, TempNode): # TempNode means annotation-only, so has no default return False if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME: # The "default value" is a call to `Field`; at this point, the field has a default if and only if: # * there is a positional argument that is not `...` # * there is a keyword argument named "default" that is not `...` # * there is a "default_factory" that is not `None` for arg, name in zip(expr.args, expr.arg_names): # If name is None, then this arg is the default because it is the only positional argument. if name is None or name == 'default': return arg.__class__ is not EllipsisExpr if name == 'default_factory': return not (isinstance(arg, NameExpr) and arg.fullname == 'builtins.None') return False # Has no default if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`) return not isinstance(expr, EllipsisExpr) @staticmethod def get_strict(stmt: AssignmentStmt) -> bool | None: """Returns a the `strict` value of a field if defined, otherwise `None`.""" expr = stmt.rvalue if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME: for arg, name in zip(expr.args, expr.arg_names): if name != 'strict': continue if isinstance(arg, NameExpr): if arg.fullname == 'builtins.True': return True elif arg.fullname == 'builtins.False': return False return None return None @staticmethod def get_alias_info(stmt: AssignmentStmt) -> tuple[str | None, bool]: """Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`. `has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal. If `has_dynamic_alias` is True, `alias` will be None. """ expr = stmt.rvalue if isinstance(expr, TempNode): # TempNode means annotation-only return None, False if not ( isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME ): # Assigned value is not a call to pydantic.fields.Field return None, False for i, arg_name in enumerate(expr.arg_names): if arg_name != 'alias': continue arg = expr.args[i] if isinstance(arg, StrExpr): return arg.value, False else: return None, True return None, False @staticmethod def is_field_frozen(stmt: AssignmentStmt) -> bool: """Returns whether the field is frozen, extracted from the declaration of the field defined in `stmt`. Note that this is only whether the field was declared to be frozen in a ` = Field(frozen=True)` sense; this does not determine whether the field is frozen because the entire model is frozen; that is handled separately. """ expr = stmt.rvalue if isinstance(expr, TempNode): # TempNode means annotation-only return False if not ( isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME ): # Assigned value is not a call to pydantic.fields.Field return False for i, arg_name in enumerate(expr.arg_names): if arg_name == 'frozen': arg = expr.args[i] return isinstance(arg, NameExpr) and arg.fullname == 'builtins.True' return False def get_field_arguments( self, fields: list[PydanticModelField], typed: bool, model_strict: bool, use_alias: bool, requires_dynamic_aliases: bool, is_settings: bool, is_root_model: bool, force_typevars_invariant: bool = False, ) -> list[Argument]: """Helper function used during the construction of the `__init__` and `model_construct` method signatures. Returns a list of mypy Argument instances for use in the generated signatures. """ info = self._cls.info arguments = [ field.to_argument( info, typed=typed, model_strict=model_strict, force_optional=requires_dynamic_aliases or is_settings, use_alias=use_alias, api=self._api, force_typevars_invariant=force_typevars_invariant, is_root_model_root=is_root_model and field.name == 'root', ) for field in fields if not (use_alias and field.has_dynamic_alias) ] return arguments def should_init_forbid_extra(self, fields: list[PydanticModelField], config: ModelConfigData) -> bool: """Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature. We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to, *unless* a required dynamic alias is present (since then we can't determine a valid signature). """ if not config.populate_by_name: if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)): return False if config.forbid_extra: return True return self.plugin_config.init_forbid_extra @staticmethod def is_dynamic_alias_present(fields: list[PydanticModelField], has_alias_generator: bool) -> bool: """Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be determined during static analysis. """ for field in fields: if field.has_dynamic_alias: return True if has_alias_generator: for field in fields: if field.alias is None: return True return False class ModelConfigData: """Pydantic mypy plugin model config class.""" def __init__( self, forbid_extra: bool | None = None, frozen: bool | None = None, from_attributes: bool | None = None, populate_by_name: bool | None = None, has_alias_generator: bool | None = None, strict: bool | None = None, ): self.forbid_extra = forbid_extra self.frozen = frozen self.from_attributes = from_attributes self.populate_by_name = populate_by_name self.has_alias_generator = has_alias_generator self.strict = strict def get_values_dict(self) -> dict[str, Any]: """Returns a dict of Pydantic model config names to their values. It includes the config if config value is not `None`. """ return {k: v for k, v in self.__dict__.items() if v is not None} def update(self, config: ModelConfigData | None) -> None: """Update Pydantic model config values.""" if config is None: return for k, v in config.get_values_dict().items(): setattr(self, k, v) def setdefault(self, key: str, value: Any) -> None: """Set default value for Pydantic model config if config value is `None`.""" if getattr(self, key) is None: setattr(self, key, value) ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_attributes call', 'Pydantic') ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic') ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic') ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic') ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic') ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic') ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic') def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None: """Emits an error when the model does not have `from_attributes=True`.""" api.fail(f'"{model_name}" does not have from_attributes=True', context, code=ERROR_ORM) def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None: """Emits an error when the config value is invalid.""" api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG) def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None: """Emits required dynamic aliases error. This will be called when `warn_required_dynamic_aliases=True`. """ api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS) def error_unexpected_behavior( detail: str, api: CheckerPluginInterface | SemanticAnalyzerPluginInterface, context: Context ) -> None: # pragma: no cover """Emits unexpected behavior error.""" # Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path link = 'https://github.com/pydantic/pydantic/issues/new/choose' full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n' full_message += f'Please consider reporting this bug at {link} so we can try to fix it!' api.fail(full_message, context, code=ERROR_UNEXPECTED) def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None: """Emits an error when there is an untyped field in the model.""" api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED) def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None: """Emits an error when there is more than just a root field defined for a subclass of RootModel.""" api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL) def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None: """Emits an error when `Field` has both `default` and `default_factory` together.""" api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS) def add_method( api: SemanticAnalyzerPluginInterface | CheckerPluginInterface, cls: ClassDef, name: str, args: list[Argument], return_type: Type, self_type: Type | None = None, tvar_def: TypeVarDef | None = None, is_classmethod: bool = False, ) -> None: """Very closely related to `mypy.plugins.common.add_method_to_class`, with a few pydantic-specific changes.""" info = cls.info # First remove any previously generated methods with the same name # to avoid clashes and problems in the semantic analyzer. if name in info.names: sym = info.names[name] if sym.plugin_generated and isinstance(sym.node, FuncDef): cls.defs.body.remove(sym.node) # pragma: no cover if isinstance(api, SemanticAnalyzerPluginInterface): function_type = api.named_type('builtins.function') else: function_type = api.named_generic_type('builtins.function', []) if is_classmethod: self_type = self_type or TypeType(fill_typevars(info)) first = [Argument(Var('_cls'), self_type, None, ARG_POS, True)] else: self_type = self_type or fill_typevars(info) # `self` is positional *ONLY* here, but this can't be expressed # fully in the mypy internal API. ARG_POS is the closest we can get. # Using ARG_POS will, however, give mypy errors if a `self` field # is present on a model: # # Name "self" already defined (possibly by an import) [no-redef] # # As a workaround, we give this argument a name that will # never conflict. By its positional nature, this name will not # be used or exposed to users. first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)] args = first + args arg_types, arg_names, arg_kinds = [], [], [] for arg in args: assert arg.type_annotation, 'All arguments must be fully typed.' arg_types.append(arg.type_annotation) arg_names.append(arg.variable.name) arg_kinds.append(arg.kind) signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type) if tvar_def: signature.variables = [tvar_def] func = FuncDef(name, args, Block([PassStmt()])) func.info = info func.type = set_callable_name(signature, func) func.is_class = is_classmethod func._fullname = info.fullname + '.' + name func.line = info.line # NOTE: we would like the plugin generated node to dominate, but we still # need to keep any existing definitions so they get semantically analyzed. if name in info.names: # Get a nice unique name instead. r_name = get_unique_redefinition_name(name, info.names) info.names[r_name] = info.names[name] # Add decorator for is_classmethod # The dataclasses plugin claims this is unnecessary for classmethods, but not including it results in a # signature incompatible with the superclass, which causes mypy errors to occur for every subclass of BaseModel. if is_classmethod: func.is_decorated = True v = Var(name, func.type) v.info = info v._fullname = func._fullname v.is_classmethod = True dec = Decorator(func, [NameExpr('classmethod')], v) dec.line = info.line sym = SymbolTableNode(MDEF, dec) else: sym = SymbolTableNode(MDEF, func) sym.plugin_generated = True info.names[name] = sym info.defn.defs.body.append(func) def parse_toml(config_file: str) -> dict[str, Any] | None: """Returns a dict of config keys to values. It reads configs from toml file and returns `None` if the file is not a toml file. """ if not config_file.endswith('.toml'): return None if sys.version_info >= (3, 11): import tomllib as toml_ else: try: import tomli as toml_ except ImportError: # pragma: no cover import warnings warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.') return None with open(config_file, 'rb') as rf: return toml_.load(rf)