深入探索联邦学习框架 Flower

A deep dive into the federated learning framework Flower

Posted by Bryan on November 16, 2023

联邦学习框架

本文主要期望介绍一个设计良好的联邦学习框架 Flower,在开始介绍 Flower 框架的细节前,先了解下联邦学习框架的基础知识。

作为一个联邦学习框架,必然会包含对横向联邦学习的支持。横向联邦是指拥有类似数据的多方可以在不泄露数据的情况下联合训练出一个模型,这个模型可以充分利用各方的数据,接近将全部数据集中在一起进行训练的效果。横向联邦学习的一般流程如下:

federated

横向联邦学习的过程简单理解如下:

  1. 各个参与方基于自身的数据训练出本地模型,将模型参数发送给公共的服务端;
  2. 服务端将收到的多个模型参数聚合生成全局的模型参数,然后下发给各个参与方;
  3. 参与方使用全局的模型参数更新本地模型,重复这个步骤直到模型训练达到要求;

从上面的过程可以看到,作为一个联邦学习框架,需要关注下面要点:

  1. 参与方本地模型训练;
  2. 模型参数的传输;
  3. 模型的聚合策略;

Flower 框架上手

Flower 是一个轻量的联邦学习框架,提出于 2020 年。一直以来,因为设计良好,方便扩展受到了比较多的关注。团队通过论文 FLOWER: A FRIENDLY FEDERATED LEARNING FRAMEWORK 介绍了框架的设计思想。通过论文可以看到框架设计主要追求下面目标:

  1. 可拓展,支持大量的客户端同时进行模型训练;
  2. 使用灵活,支持异构的客户端,通信协议,隐私策略,支持新功能的开销小;

首先基于 Flower 框架实际进行了一个机器学习的模型训练,通过实际动手可以感受下基于 Flower 框架可以用相当简单的方式实现一个联邦学习模型训练流程。这个流程是参考 Flower Quickstart 实现的。

常规机器学习部分实现

首先实现机器学习训练所需的基础方法,主要是数据集的准备,定义所需的模型,封装训练与测试流程,这部分与联邦学习无关。熟悉 pytorch 的应该很容易理解这部分代码:

from collections import OrderedDict

import flwr as fl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm


DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 定义神经网络模型

class Net(nn.Module):
    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# 定义模型训练流程

def train(net, trainloader, epochs):
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in tqdm(trainloader):
            optimizer.zero_grad()
            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
            optimizer.step()

# 定义模型推理流程

def test(net, testloader):
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            outputs = net(images.to(DEVICE))
            labels = labels.to(DEVICE)
            loss += criterion(outputs, labels).item()
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    return loss, accuracy

# 定义数据集的获取

def load_data():
    """Load CIFAR-10 (training and test set)."""
    trf = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    trainset = CIFAR10("./data", train=True, download=True, transform=trf)
    testset = CIFAR10("./data", train=False, download=True, transform=trf)
    return DataLoader(trainset, batch_size=32, shuffle=True), DataLoader(testset)

# 生成模型对象,实际获取训练与测试数据集

net = Net().to(DEVICE)
trainloader, testloader = load_data()

如果是常规的机器学习模型训练,直接调用上面的 train(net, trainloader, epochs=1) 即可完成模型的训练,但是此时各方只能使用本地有限的数据训练出一个机器学习模型。如果希望能充分利用各方的数据训练出一个更好的机器学习模型,就需要基于 Flower 补充联邦学习的能力了。

Flower 客户端实现

Flower 的客户端表示的是各个参与联合训练的一方,Flower 客户端会完成本地的模型训练,并将本地训练的模型发送给服务端,然后接收服务端下发的聚合模型。需要实现如下:

class FlowerClient(fl.client.NumPyClient):
    # 获取本地模型对应的参数

    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    # 接收模型参数,并更新本地模型

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    # 本地模型训练,会先调用 set_parameters() 基于收到的全局模型参数更新本地模型

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), len(trainloader.dataset), {}

    # 基于测试数据集进行测试

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return loss, len(testloader.dataset), {"accuracy": accuracy}

# 启动 Flower 客户端

fl.client.start_numpy_client(
    server_address="127.0.0.1:8080",
    client=FlowerClient(),
)

可以看到 Flower 客户端的训练流程中,除了实现 fit() 方法进行本地模型的训练之外,只需要额外实现两个方法,使用 get_parameters() 获取了本地模型的参数,方便发送本地模型参数给服务端用于模型聚合,使用 set_parameters() 方便根据服务端发来的聚合模型参数更新本地模型。

Flower 服务端实现

Flower 服务端主要用于接收客户端发来的模型参数,聚合参数后下发给客户端,具体实现如下:

from typing import List, Tuple

import flwr as fl
from flwr.common import Metrics

# 定义指标聚合方法

def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    accuracies = [num_examples * m["accuracy"] for num_examples, m in metrics]
    examples = [num_examples for num_examples, _ in metrics]

    return {"accuracy": sum(accuracies) / sum(examples)}

# 定义模型聚合策略

strategy = fl.server.strategy.FedAvg(evaluate_metrics_aggregation_fn=weighted_average)

# 启动 Flower 服务端

fl.server.start_server(
    server_address="0.0.0.0:8080",
    config=fl.server.ServerConfig(num_rounds=3),
    strategy=strategy,
)

可以看到 Flower 服务端只需要定义模型聚合的策略,由于机器学习的主要流程是固定的,因此不需要手工实现。

通过上面的实现即可完成联邦学习模型训练,实际需要先启动 Flower 服务端,然后可以根据需要启动多个 Flower 客户端,即可进行联邦学习模型训练。

Flower 框架设计

Flower 框架是如何设计从而实现上面的便利使用的呢?可以先看看官方的架构图:

flower

上面的框架中 RPC Server 就是服务端,作为中心节点协调大量的客户端进行联合的模型训练。服务端中包含三大核心组件:

  1. ClientManager,用于管理现有的客户端,提供 ClientProxy 作为客户端的抽象,方便统一处理,利用 ClientProxy 进行实际数据的发送,基于 ClientProxy 消除了客户端的异构性;
  2. Strategy,用于确定每个阶段如何执行。比如在训练阶段,通过 Strategy 确定采取什么样的模型聚合策略。Strategy 是可以由用户自定义的;
  3. FL loop,定义整个执行机器学习流程,因为机器学习的训练与预测流程都是固定的,因此根据确定的任务即可确定对应的 FL loop 流程;

Flower 源码实现

通过上面的框架我们可以理解 Flower 是如何实现灵活性,接下来我们可以关注具体的源码实现,具体了解 Flower 框架如何将我们实现的客户端与服务端串联起来,并了解 Flower 框架的支撑组件的具体信息。

Flower 客户端

Flower 的客户端的基础类 NumPyClient 主要定义了一些必要的方法,这些方法需要在子类中实现,这个在上面的动手阶段已经体现出来,没有太多需要深入了解的。我们主要关注下客户端的启动方法 start_numpy_client(),这部分的实现主要在 src/py/flwr/client/app.py 中的 start_client() 方法中完成,具体的实现如下:

def start_client(
    *,
    server_address: str,
    client_fn: Optional[ClientFn] = None,
    client: Optional[Client] = None,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    root_certificates: Optional[Union[bytes, str]] = None,
    transport: Optional[str] = None,
) -> None:

    # 初始化连接上下文

    connection, address = _init_connection(transport, server_address)

    while True:
        sleep_duration: int = 0
        # 建立 RPC 连接

        with connection(
            address,
            grpc_max_message_length,
            root_certificates,
        ) as conn:
            receive, send, create_node, delete_node = conn

            # 注册当前客户端

            if create_node is not None:
                create_node()

            while True:
                # 接收服务端发来的消息

                task_ins = receive()
                if task_ins is None:
                    time.sleep(3)
                    continue

                # 处理系统控制类消息,处理完控制类消息就退出

                task_res, sleep_duration = handle_control_message(task_ins=task_ins)
                if task_res:
                    send(task_res)
                    break

                # 处理普通任务中的消息

                task_res = handle(client_fn, task_ins)

                # 将执行结果发送给服务端

                send(task_res)

            # 注销当前客户端

            if delete_node is not None:
                delete_node()

        if sleep_duration == 0:
            break

        time.sleep(sleep_duration)

可以看到整体的实现流程还是比较简单的,客户端扮演一个被动的角色,通过 receive() 持续接收服务端发来的消息,系统消息通过 handle_control_message() 方法处理, 常规消息通过 handle() 方法进行处理,然后将处理的结果通过 send() 方法发送给服务端。

系统消息只包含一种 reconnect_ins 消息,主要用于客户端重连或断开,目前主要常规消息的处理,对应的处理方法 handle() 实现如下所示:

def handle(client_fn: ClientFn, task_ins: TaskIns) -> TaskRes:
    server_msg = get_server_message_from_task_ins(task_ins, exclude_reconnect_ins=False)

    # 临时支持安全聚合的分支

    if server_msg is None:
        client = client_fn("-1")
        if task_ins.task.HasField("sa") and isinstance(
            client, SecureAggregationHandler
        ):
            named_values = serde.named_values_from_proto(task_ins.task.sa.named_values)
            res = client.handle_secure_aggregation(named_values)
            task_res = TaskRes(
                task_id="",
                group_id="",
                workload_id=0,
                task=Task(
                    ancestry=[],
                    sa=SecureAggregation(named_values=serde.named_values_to_proto(res)),
                ),
            )
            return task_res
        raise NotImplementedError()

    # 处理消息

    client_msg = handle_legacy_message(client_fn, server_msg)

    # 处理结果的封装

    task_res = wrap_client_message_in_task_res(client_msg)
    return task_res

def handle_legacy_message(
    client_fn: ClientFn, server_msg: ServerMessage
) -> ClientMessage:
    field = server_msg.WhichOneof("msg")

    client = client_fn("-1")

    if field == "get_properties_ins":
        return _get_properties(client, server_msg.get_properties_ins)

    # 获取客户端的模型参数

    if field == "get_parameters_ins":
        return _get_parameters(client, server_msg.get_parameters_ins)

    # 客户端模型训练

    if field == "fit_ins":
        return _fit(client, server_msg.fit_ins)

    # 客户端基于测试集进行测试

    if field == "evaluate_ins":
        return _evaluate(client, server_msg.evaluate_ins)
    raise UnknownServerMessage()

上面的消息处理最终调用的就是客户端需要实现的 fit()get_parameters() 等方法。可以看到客户端的实现还是比较简单的,只是根据几种消息,执行预先支持的对应的方法即可。考虑到客户端完全是被动响应服务端的消息,因此主要的联邦学习的流程的支持都是由服务端定义好的。

Flower 服务端

Flower 服务端需要自定义的代码较少,主流程是通过直接调用 src/py/flwr/server/app.py 中的 start_server()方法完成的,我们可以了解下对应的实现:

def start_server(
    *,
    server_address: str = ADDRESS_FLEET_API_GRPC_BIDI,
    server: Optional[Server] = None,
    config: Optional[ServerConfig] = None,
    strategy: Optional[Strategy] = None,
    client_manager: Optional[ClientManager] = None,
    grpc_max_message_length: int = GRPC_MAX_MESSAGE_LENGTH,
    certificates: Optional[Tuple[bytes, bytes, bytes]] = None,
) -> History:

    # 构造启动地址

    parsed_address = parse_address(server_address)
    if not parsed_address:
        sys.exit(f"Server IP address ({server_address}) cannot be parsed.")
    host, port, is_v6 = parsed_address
    address = f"[{host}]:{port}" if is_v6 else f"{host}:{port}"

    # 初始化 Server 对象,此对象中包含实际的模型训练流程的支持

    initialized_server, initialized_config = init_defaults(
        server=server,
        config=config,
        strategy=strategy,
        client_manager=client_manager,
    )

    # 启动 grpc 服务端,用于与客户端进行通信

    grpc_server = start_grpc_server(
        client_manager=initialized_server.client_manager(),
        server_address=address,
        max_message_length=grpc_max_message_length,
        certificates=certificates,
    )

    # 执行训练流程

    hist = run_fl(
        server=initialized_server,
        config=initialized_config,
    )

    # 停止 grpc 服务端

    grpc_server.stop(grace=1)

    return hist

执行训练流程的方法 run_fl(initialized_server) 事实上就是调用 initialized_server.fit() 方法,主要的训练流程都是在 src/py/flwr/server/server.py 中的 Server.fit() 实现的,下面就重点关注这边的实现:

def fit(self, num_rounds: int, timeout: Optional[float]) -> History:
    # 初始化全局模型参数

    self.parameters = self._get_initial_parameters(timeout=timeout)

    # 执行 num_rounds 轮模型训练

    for current_round in range(1, num_rounds + 1):
        # 执行单轮机器学习模型训练

        res_fit = self.fit_round(
            server_round=current_round,
            timeout=timeout,
        )
        if res_fit is not None:
            parameters_prime, fit_metrics, _ = res_fit
            # 根据聚合生成的模型参数更新全局模型参数

            if parameters_prime:
                self.parameters = parameters_prime

可以看到单轮的联邦学习模型训练就是通过 fit_round() 实现,具体的代码如下所示:

def fit_round(
    self,
    server_round: int,
    timeout: Optional[float],
) -> Optional[
    Tuple[Optional[Parameters], Dict[str, Scalar], FitResultsAndFailures]
]:
    client_instructions = self.strategy.configure_fit(
        server_round=server_round,
        parameters=self.parameters,
        client_manager=self._client_manager,
    )

    # 客户端基于本地数据进行模型训练

    results, failures = fit_clients(
        client_instructions=client_instructions,
        max_workers=self.max_workers,
        timeout=timeout,
    )

    # 聚合客户端发来的模型参数

    aggregated_result: Tuple[
        Optional[Parameters],
        Dict[str, Scalar],
    ] = self.strategy.aggregate_fit(server_round, results, failures)

    parameters_aggregated, metrics_aggregated = aggregated_result
    return parameters_aggregated, metrics_aggregated, (results, failures)

下面具体 fit_clients() 是如何发起模型训练的:

def fit_clients(
    client_instructions: List[Tuple[ClientProxy, FitIns]],
    max_workers: Optional[int],
    timeout: Optional[float],
) -> FitResultsAndFailures:
    # 多线程并发调用 fit_client 方法实现客户端模型训练

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        submitted_fs = {
            executor.submit(fit_client, client_proxy, ins, timeout)
            for client_proxy, ins in client_instructions
        }
        finished_fs, _ = concurrent.futures.wait(
            fs=submitted_fs,
            timeout=None,
        )

    results: List[Tuple[ClientProxy, FitRes]] = []
    failures: List[Union[Tuple[ClientProxy, FitRes], BaseException]] = []
    for future in finished_fs:
        _handle_finished_future_after_fit(
            future=future, results=results, failures=failures
        )
    return results, failures

# 通过 ClientProxy 发起客户端的 fit() 模型训练,ins 中包含全局模型参数

def fit_client(
    client: ClientProxy, ins: FitIns, timeout: Optional[float]
) -> Tuple[ClientProxy, FitRes]:
    fit_res = client.fit(ins, timeout=timeout)
    return client, fit_res

总结而言:服务端就是通过多次调用 fit_round() 方法实现多轮的联邦学习模型训练,在单轮的联邦学习模型训练中,客户端会根据全局模型参数更新本地模型参数,然后进行本地的模型训练,并将训练后的模型的参数发回给服务端,服务端会根据 Strategy 策略聚合客户端发来的模型参数,然后更新服务端的全局模型参数。

总结

通过上面的内容,将 Flower 框架的动手实践以及对应的实现细节都介绍到了,主要涉及到 Flower 三大核心组件中的 Strategy 与 FL loop,而 ClientManager 目前没有过多展开,这部分主要用于管理客户端的连接,有兴趣的可以自行去探索下。

从目前来看,Flower 基本上是一个最精简的横向联邦学习的实现方案了,通过必要的抽象简化,Flower 将横向联邦用简单易用的方式进行了封装,对于了解横向联邦学习有很大的帮助。而且值得一提的是,Flower 本身具备比较好的灵活性,可以比较方便地支持不同联邦学习策略,因为 Proxy 机制的存在也能灵活支持异构的客户端,对于设计一个高效的联邦学习框架有不少的借鉴意义。