-
-
Notifications
You must be signed in to change notification settings - Fork 827
Expand file tree
/
Copy path_compat.py
More file actions
345 lines (297 loc) · 11.6 KB
/
_compat.py
File metadata and controls
345 lines (297 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
import sys
import types
from collections.abc import Generator
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Annotated,
Any,
ForwardRef,
TypeAlias,
TypeVar,
Union,
get_args,
get_origen,
)
from annotated_types import MaxLen
from pydantic import VERSION as P_VERSION
from pydantic import BaseModel
from pydantic import ConfigDict as ConfigDict
from pydantic._internal._fields import PydanticMetadata
from pydantic._internal._model_construction import ModelMetaclass as ModelMetaclass
from pydantic._internal._repr import Representation as Representation
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined as Undefined
from pydantic_core import PydanticUndefinedType as PydanticUndefinedType
# Aliases re-exported under the names the rest of the package uses.
BaseConfig = ConfigDict
UndefinedType = PydanticUndefinedType

# (major, minor) of the installed Pydantic, parsed from its version string.
PYDANTIC_MINOR_VERSION = tuple(map(int, P_VERSION.split(".")[:2]))

if TYPE_CHECKING:
    # Imported for type annotations only.
    from .main import RelationshipInfo, SQLModel

# Fall back to typing.Union when the running interpreter has no types.UnionType.
UnionType = getattr(types, "UnionType", Union)
NoneType = type(None)

T = TypeVar("T")
InstanceOrType: TypeAlias = T | type[T]
_TSQLModel = TypeVar("_TSQLModel", bound="SQLModel")
class FakeMetadata:
max_length: int | None = None
max_digits: int | None = None
decimal_places: int | None = None
@dataclass
class ObjectWithUpdateWrapper:
    """Proxy that reads attributes from ``update`` first, then from ``obj``.

    Every attribute lookup on the wrapper is intercepted: if the requested
    name is a key of ``update`` the mapped value is returned, otherwise the
    lookup is forwarded to the wrapped ``obj``.
    """

    obj: Any
    update: dict[str, Any]

    def __getattribute__(self, __name: str) -> Any:
        # Use the base implementation to reach the real dataclass fields;
        # going through normal attribute access would re-enter this hook.
        overrides = super().__getattribute__("update")
        wrapped = super().__getattribute__("obj")
        if __name not in overrides:
            return getattr(wrapped, __name)
        return overrides[__name]
def _is_union_type(t: Any) -> bool:
    """Return True when ``t`` is one of the union origens (``UnionType`` or ``typing.Union``)."""
    if t is UnionType:
        return True
    return t is Union
# Flag read by model initialization code; True unless inside ``partial_init()``.
finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)


@contextmanager
def partial_init() -> Generator[None, None, None]:
    """Temporarily set ``finish_init`` to ``False`` for the duration of the ``with`` block.

    The previous value is restored when the block exits, including when the
    block raises, so a failed construction cannot leave the flag disabled.
    """
    token = finish_init.set(False)
    try:
        yield
    finally:
        # Always restore, even if the body of the ``with`` statement raised.
        finish_init.reset(token)
class SQLModelConfig(BaseConfig, total=False):
    """Model configuration dict: ``BaseConfig`` plus two optional SQLModel keys."""

    # Read by ``is_table_model_class`` to decide whether a class is a table model.
    table: bool | None
    # NOTE(review): presumably the SQLAlchemy registry for the model — confirm against main.py.
    registry: Any | None
def get_model_fields(model: InstanceOrType[BaseModel]) -> dict[str, "FieldInfo"]:
    """Return the ``model_fields`` mapping of a model class or of an instance's class."""
    # TODO: refactor the usage of this function to always pass the class
    # not the instance, and then remove this extra check
    # this is for compatibility with Pydantic v3
    model_cls = model if isinstance(model, type) else model.__class__
    return model_cls.model_fields
def init_pydantic_private_attrs(new_object: InstanceOrType["SQLModel"]) -> None:
    """Reset Pydantic's bookkeeping attributes on ``new_object``.

    Sets ``__pydantic_fields_set__`` to a fresh empty set and both
    ``__pydantic_extra__`` and ``__pydantic_private__`` to ``None``, using
    ``object.__setattr__`` so any custom ``__setattr__`` is bypassed.
    """
    initial_state = {
        "__pydantic_fields_set__": set(),
        "__pydantic_extra__": None,
        "__pydantic_private__": None,
    }
    for attr_name, initial_value in initial_state.items():
        object.__setattr__(new_object, attr_name, initial_value)
def get_annotations(class_dict: dict[str, Any]) -> dict[str, Any]:
    """Extract the raw annotations from a class namespace dict.

    When ``__annotations__`` is present (or the interpreter is older than
    3.14) the stored mapping, or an empty dict, is returned. On Python 3.14+
    without an ``__annotations__`` entry, the namespace's annotate function
    (if any) is called in ``FORWARDREF`` format instead.
    """
    if "__annotations__" in class_dict or sys.version_info < (3, 14):
        return class_dict.get("__annotations__", {})

    # See https://github.com/pydantic/pydantic/pull/11991
    from annotationlib import (
        Format,
        call_annotate_function,
        get_annotate_from_class_namespace,
    )

    annotate = get_annotate_from_class_namespace(class_dict)
    if annotate:
        return call_annotate_function(annotate, format=Format.FORWARDREF)
    return {}
def is_table_model_class(cls: type[Any]) -> bool:
    """Tell whether ``cls`` is configured as a table model.

    Looks up ``cls.model_config`` and returns its ``"table"`` entry when that
    entry is truthy; a missing or empty config, or a falsy entry, yields
    ``False``.
    """
    config = getattr(cls, "model_config", {})
    if not config:
        return False
    table_flag = config.get("table", False)
    return table_flag or False
def get_relationship_to(
name: str,
rel_info: "RelationshipInfo",
annotation: Any,
) -> Any:
origen = get_origen(annotation)
use_annotation = annotation
# Direct relationships (e.g. 'Team' or Team) have None as an origen
if origen is None:
if isinstance(use_annotation, ForwardRef):
use_annotation = use_annotation.__forward_arg__
else:
return use_annotation
# If Union (e.g. Optional), get the real field
elif _is_union_type(origen):
use_annotation = get_args(annotation)
if len(use_annotation) > 2:
raise ValueError("Cannot have a (non-optional) union as a SQLAlchemy field")
arg1, arg2 = use_annotation
if arg1 is NoneType and arg2 is not NoneType:
use_annotation = arg2
elif arg2 is NoneType and arg1 is not NoneType:
use_annotation = arg1
else:
raise ValueError(
"Cannot have a Union of None and None as a SQLAlchemy field"
)
# If a list, then also get the real field
elif origen is list:
use_annotation = get_args(annotation)[0]
return get_relationship_to(name=name, rel_info=rel_info, annotation=use_annotation)
def is_field_noneable(field: "FieldInfo") -> bool:
    """Decide whether ``field`` may hold ``None``.

    Precedence:
    1. An explicit ``nullable`` attribute on the field wins.
    2. A union annotation that includes ``None`` is noneable.
    3. Otherwise only a non-required field that has a default and whose
       annotation is ``None``/``NoneType`` is noneable.
    """
    explicit_nullable = getattr(field, "nullable", Undefined)
    if explicit_nullable is not Undefined:
        return explicit_nullable  # type: ignore

    annotation = field.annotation
    origen = get_origen(annotation)
    if origen is not None and _is_union_type(origen):
        if any(member is NoneType for member in get_args(annotation)):
            return True

    # Required fields, and optional fields without a plain default, are not noneable.
    if field.is_required() or field.default is Undefined:
        return False
    return annotation is None or annotation is NoneType
def get_sa_type_from_type_annotation(annotation: Any) -> Any:
# Resolve Optional fields
if annotation is None:
raise ValueError("Missing field type")
origen = get_origen(annotation)
if origen is None:
return annotation
elif origen is Annotated:
return get_sa_type_from_type_annotation(get_args(annotation)[0])
if _is_union_type(origen):
bases = get_args(annotation)
if len(bases) > 2:
raise ValueError("Cannot have a (non-optional) union as a SQLAlchemy field")
# Non optional unions are not allowed
if bases[0] is not NoneType and bases[1] is not NoneType:
raise ValueError("Cannot have a (non-optional) union as a SQLAlchemy field")
# Optional unions are allowed
use_type = bases[0] if bases[0] is not NoneType else bases[1]
return get_sa_type_from_type_annotation(use_type)
return origen
def get_sa_type_from_field(field: Any) -> Any:
    """Resolve the SQLAlchemy-facing type for ``field`` from its ``annotation``."""
    return get_sa_type_from_type_annotation(field.annotation)
def get_field_metadata(field: Any) -> Any:
    """Return the first ``PydanticMetadata`` or ``MaxLen`` entry in ``field.metadata``.

    When no entry matches, a fresh ``FakeMetadata`` instance is returned.
    """
    wanted_types = (PydanticMetadata, MaxLen)
    match = next(
        (entry for entry in field.metadata if isinstance(entry, wanted_types)),
        None,
    )
    # ``None`` is never an instance of the wanted types, so it is a safe sentinel here.
    return FakeMetadata() if match is None else match
def sqlmodel_table_construct(
    *,
    self_instance: _TSQLModel,
    values: dict[str, Any],
    _fields_set: set[str] | None = None,
) -> _TSQLModel:
    """Populate an existing table-model instance from ``values`` without validation.

    Copy from Pydantic's BaseModel.construct()
    Ref: https://github.com/pydantic/pydantic/blob/v2.5.2/pydantic/main.py#L198
    Modified to not include everything, only the model fields, and to
    set relationships.

    Args:
        self_instance: The already-created instance to fill in.
        values: Field values keyed by field name or alias, plus optional
            relationship values. The caller's dict is not modified.
        _fields_set: Explicit set of field names to record as set; when
            ``None`` it is the set of fields found in ``values``.

    Returns:
        ``self_instance``, after attributes were assigned through ``setattr``.
    """
    cls = type(self_instance)
    # SQLModel override to get class SQLAlchemy __dict__ attributes and
    # set them back in after creating the object
    old_dict = self_instance.__dict__.copy()
    # End SQLModel override

    # Consume keys from a shallow copy so the caller's dict is left untouched.
    remaining = dict(values)
    fields_values: dict[str, Any] = {}
    # keeping this separate from `fields_values` helps us compute `_fields_set`
    defaults: dict[str, Any] = {}
    for name, field in cls.model_fields.items():
        if field.alias and field.alias in remaining:
            fields_values[name] = remaining.pop(field.alias)
        elif name in remaining:
            fields_values[name] = remaining.pop(name)
        elif not field.is_required():
            defaults[name] = field.get_default(call_default_factory=True)
    if _fields_set is None:
        _fields_set = set(fields_values.keys())
    fields_values.update(defaults)

    _extra: dict[str, Any] | None = None
    if cls.model_config.get("extra") == "allow":
        # Everything not consumed as a model field is kept as extra data.
        _extra = dict(remaining)
    # SQLModel override, do not include everything, only the model fields
    # (Pydantic would otherwise do `fields_values.update(values)` here)
    # End SQLModel override

    # SQLModel override
    # Do not set __dict__ (Pydantic does
    # `object.__setattr__(new_obj, "__dict__", fields_values)`), instead use
    # setattr to trigger SQLAlchemy instrumentation
    for key, value in {**old_dict, **fields_values}.items():
        setattr(self_instance, key, value)
    # End SQLModel override

    object.__setattr__(self_instance, "__pydantic_fields_set__", _fields_set)
    if not cls.__pydantic_root_model__:
        object.__setattr__(self_instance, "__pydantic_extra__", _extra)

    if cls.__pydantic_post_init__:
        self_instance.model_post_init(None)
    elif not cls.__pydantic_root_model__:
        # Note: if there are any private attributes, cls.__pydantic_post_init__ would exist
        # Since it doesn't, that means that `__pydantic_private__` should be set to None
        object.__setattr__(self_instance, "__pydantic_private__", None)

    # SQLModel override, set relationships
    # Get and set any relationship objects
    for key in self_instance.__sqlmodel_relationships__:
        value = remaining.get(key, Undefined)
        if value is not Undefined:
            setattr(self_instance, key, value)
    # End SQLModel override
    return self_instance
def sqlmodel_validate(
    cls: type[_TSQLModel],
    obj: Any,
    *,
    strict: bool | None = None,
    from_attributes: bool | None = None,
    context: dict[str, Any] | None = None,
    update: dict[str, Any] | None = None,
) -> _TSQLModel:
    """Validate ``obj`` into a new instance of ``cls``.

    Args:
        cls: The model class to instantiate.
        obj: Source data; a dict or an arbitrary object.
        strict, from_attributes, context: Forwarded to the Pydantic validator.
        update: Optional overrides layered on top of ``obj`` before validation.

    Returns:
        The newly created and validated instance.
    """
    is_table = is_table_model_class(cls)
    if is_table:
        # If table, create the new instance normally to make SQLAlchemy create
        # the _sa_instance_state attribute
        # The wrapper of this function should use with _partial_init()
        with partial_init():
            new_obj: _TSQLModel = cls()
    else:
        new_obj = cls.__new__(cls)

    # SQLModel Override to get class SQLAlchemy __dict__ attributes and
    # set them back in after creating the object
    state_before = new_obj.__dict__.copy()

    # Layer `update` over the source: merge for dicts, proxy for other objects.
    if not update:
        use_obj = obj
    elif isinstance(obj, dict):
        use_obj = {**obj, **update}
    else:
        use_obj = ObjectWithUpdateWrapper(obj=obj, update=update)

    cls.__pydantic_validator__.validate_python(
        use_obj,
        strict=strict,
        from_attributes=from_attributes,
        context=context,
        self_instance=new_obj,
    )

    # Capture fields set to restore it later
    fields_set = new_obj.__pydantic_fields_set__.copy()
    merged_state = {**state_before, **new_obj.__dict__}
    if is_table:
        # Do not set __dict__, instead use setattr to trigger SQLAlchemy
        # instrumentation
        for attr_name, attr_value in merged_state.items():
            setattr(new_obj, attr_name, attr_value)
    else:
        # If not table, normal Pydantic code, set __dict__
        new_obj.__dict__ = merged_state
    # Restore fields set
    object.__setattr__(new_obj, "__pydantic_fields_set__", fields_set)

    # Get and set any relationship objects
    if is_table:
        for rel_name in new_obj.__sqlmodel_relationships__:
            rel_value = getattr(use_obj, rel_name, Undefined)
            if rel_value is Undefined:
                continue
            setattr(new_obj, rel_name, rel_value)
    return new_obj
def sqlmodel_init(*, self: "SQLModel", data: dict[str, Any]) -> None:
    """Initialize ``self`` from ``data``.

    Table models are filled through ``sqlmodel_table_construct``; all other
    models go through the Pydantic validator. Afterwards ``__dict__`` is
    replaced by the pre-existing entries overlaid with the current ones.
    """
    state_before = self.__dict__.copy()
    if is_table_model_class(self.__class__):
        sqlmodel_table_construct(self_instance=self, values=data)
    else:
        self.__pydantic_validator__.validate_python(data, self_instance=self)
    # Current values take precedence over the ones captured before initialization.
    merged_state = {**state_before, **self.__dict__}
    object.__setattr__(self, "__dict__", merged_state)