Python 3 依赖注入

TLDR

FastAPI 中常见如下写法：

1
2
3
4
5
6
@router("/user")
async def read_user(
username: str,
user: Annotated[str, Depends(get_user)],
db: Annotated[str, Depends(get_db)]
):

这种写法的实现核心是 decorator + inspect：

装饰器（decorator）负责包裹并改写原函数，同时可以在闭包中保存所需的变量；

而 inspect 可以读取函数签名（参数名、默认值、类型注解等），据此完成注入。

Demo

下面被装饰的函数会变成装饰器里的 wrapper。每次调用 wrapper 都会新建一个 RequestContext，因此依赖的结果只在单次请求内复用；如果需要跨越多个请求共享的 Context，只需在 wrapper 之外再定义一个 Context。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import inspect
from functools import wraps
from typing import Annotated, Any, Callable, TypeVar, get_args, get_origin
import asyncio


class Depends:
    """Marker placed in ``Annotated[...]`` metadata to request injection.

    Example: ``user: Annotated[str, Depends(get_user)]`` asks the resolver to
    call ``get_user`` and pass its result as the ``user`` argument.
    """

    def __init__(self, dependency: Callable):
        # The callable whose return value is injected for the annotated parameter.
        self.dependency = dependency

    def __repr__(self) -> str:
        # Show the dependency's name (e.g. "Depends(get_user)") instead of the
        # default "<Depends object at 0x...>", which is unreadable when the
        # marker appears inside printed annotations or signatures.
        name = getattr(self.dependency, "__name__", type(self.dependency).__name__)
        return f"{type(self).__name__}({name})"


class RequestContext:
    """Per-request memo of dependency results.

    Results are cached by (dependency, keyword arguments), so the same
    dependency called with the same arguments is evaluated at most once per
    context.
    """

    def __init__(self):
        # (dependency, kwargs items sorted by name) -> result
        self.cache = {}

    async def get_dependency(self, dep: Callable, **kwargs) -> Any:
        """Return ``dep(**kwargs)``, awaiting it when needed, memoised in ``self.cache``.

        Calls whose argument values are unhashable cannot be memoised; they are
        evaluated on every request instead of raising ``TypeError``.
        """
        # Sort by parameter name so the key does not depend on keyword order;
        # names are unique, so the values themselves are never compared.
        key = (dep, tuple(sorted(kwargs.items())))
        try:
            return self.cache[key]
        except KeyError:
            pass
        except TypeError:
            # An argument value is unhashable (e.g. a list): skip memoisation.
            return await self._invoke(dep, kwargs)
        result = await self._invoke(dep, kwargs)
        self.cache[key] = result
        return result

    @staticmethod
    async def _invoke(dep: Callable, kwargs: dict) -> Any:
        """Call ``dep`` with ``kwargs`` and await the result if it is awaitable."""
        # Checking the result (rather than asyncio.iscoroutinefunction(dep))
        # also covers async callables such as objects with ``async def __call__``,
        # whose coroutine would otherwise be returned unawaited.
        result = dep(**kwargs)
        if inspect.isawaitable(result):
            result = await result
        return result


def _find_depends(annotation: Any) -> "Depends | None":
    """Return the first ``Depends`` marker inside an ``Annotated[...]`` annotation, or None."""
    if get_origin(annotation) is not Annotated:
        return None
    # get_args(Annotated[T, m1, m2, ...]) == (T, m1, m2, ...). Any number of
    # metadata entries is legal, so scan them instead of unpacking exactly two.
    for meta in get_args(annotation)[1:]:
        if isinstance(meta, Depends):
            return meta
    return None


async def resolve_dependencies(func: Callable, context: "RequestContext", *func_args, **func_kwargs) -> dict:
    """Bind the call's arguments to ``func`` and fill in its ``Depends`` parameters.

    Args:
        func: The handler whose signature drives the binding.
        context: Per-request cache used to resolve (and memoise) dependencies.
        *func_args, **func_kwargs: The arguments the handler was called with.

    Returns:
        A keyword dict meant for ``func(**result)``. It contains the
        caller-supplied values, ``func``'s defaults for omitted
        non-dependency parameters, the contents of any ``**kwargs``
        parameter flattened in place, and the resolved value for every
        parameter annotated ``Annotated[T, ..., Depends(dep)]``. A resolved
        value overrides anything the caller passed for that parameter.

    Raises:
        TypeError: if the arguments do not fit ``func``'s signature, or if
            extra positional arguments were bound to a ``*args`` parameter
            (they cannot be forwarded as keywords).

    Each dependency receives only the entries whose names match its own
    parameters. Those entries include dependencies resolved earlier in
    signature order.
    """
    sig = inspect.signature(func)
    bound = sig.bind_partial(*func_args, **func_kwargs)

    # Walk the signature (not the call) so the dict, and hence the cache keys
    # built from it, follows parameter order whatever order the caller used.
    kwargs: dict = {}
    for name, param in sig.parameters.items():
        if name in bound.arguments:
            value = bound.arguments[name]
        elif param.default is not inspect.Parameter.empty and _find_depends(param.annotation) is None:
            # Expose the handler's default so dependencies see the effective value.
            value = param.default
        else:
            continue

        if param.kind is inspect.Parameter.VAR_KEYWORD:
            # Flatten ``**extra`` so func(**kwargs) receives the extra keywords
            # themselves rather than one keyword named after the parameter.
            kwargs.update(value)
        elif param.kind is inspect.Parameter.VAR_POSITIONAL:
            if value:
                raise TypeError(
                    f"{func.__qualname__}: extra positional arguments cannot be forwarded as keywords"
                )
        else:
            kwargs[name] = value

    for name, param in sig.parameters.items():
        dep = _find_depends(param.annotation)
        if dep is None:
            continue
        dep_params = inspect.signature(dep.dependency).parameters
        dep_kwargs = {k: v for k, v in kwargs.items() if k in dep_params}
        kwargs[name] = await context.get_dependency(dep.dependency, **dep_kwargs)
    return kwargs


def route_decorator(path: str):
    """Build a decorator that turns ``func`` into a dependency-injected handler.

    ``path`` mirrors ``@router("/user")`` but is not used by this demo.
    The returned wrapper is always a coroutine function. A synchronous
    ``func`` is called directly and its result is returned as-is.
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # A fresh context per call, so dependency results are shared
            # only within this single invocation.
            ctx = RequestContext()
            call_kwargs = await resolve_dependencies(func, ctx, *args, **kwargs)
            outcome = func(**call_kwargs)
            if asyncio.iscoroutinefunction(func):
                outcome = await outcome
            return outcome

        return wrapper

    return decorator



async def get_user(username: str, nick: str) -> str:
    """Demo async dependency: build a user label from ``username`` and ``nick``."""
    label = f"User-{username}-{nick}"
    return label

def get_db() -> str:
    """Demo sync dependency standing in for a database handle (a placeholder string)."""
    connection = "Database Connection"
    return connection


@route_decorator("/user")
async def read_user(
    username: str,
    nick: str,
    user: Annotated[str, Depends(get_user)],
    db: Annotated[str, Depends(get_db)],
):
    """Demo handler: ``user`` and ``db`` are injected by the decorator.

    ``username`` and ``nick`` are not used here directly. They are matched by
    name to ``get_user``'s parameters when ``user`` is resolved.
    """
    return dict(user=user, db=db)


async def main():
    """Call ``read_user`` with keyword, reordered-keyword and mixed positional arguments."""
    scenarios = (
        ((), {"username": "Alice", "nick": "123"}),
        ((), {"nick": "123", "username": "Alice"}),
        (("Bob",), {"nick": "Bober"}),
    )
    for positional, named in scenarios:
        print(await read_user(*positional, **named))

if __name__ == "__main__":
    # Script entry point: run the async demo on a fresh event loop.
    asyncio.run(main())