PyTorch Participant Example

This is an example of a PyTorch implementation of a Participant class for Federated Learning. We follow the example presented in this tutorial, and it is beneficial to read it before starting here.

The example consists of two steps:

  • Setting up the Coordinator that waits for the Participants

  • Setting up the SDK Participant together with a suitable model that connects to the Coordinator.

The first step is described in the XAIN-fl repository, so from here on we assume that a Coordinator is up and waiting for Participants to join. The next step is to set up the Participant SDK and equip it with a model. We cover the requirements of the Participant Abstract Base Class, give ideas on how to handle a PyTorch model, and show how to implement a Federated Learning example. You can find the complete source code here. The example code makes use of typing to be precise about the expected data types.

Participant Abstract Base Class

The SDK provides an abstract base class for Participants which can be imported as

from xain_sdk.participant import Participant as ABCParticipant

A custom Participant should inherit from the abstract base class, like

class Participant(ABCParticipant):

and must implement the train_round() method in order to execute a round of Federated Learning, where each round consists of a certain number of epochs. This method adheres to the function signature

train_round(self, weights: Optional[np.ndarray], epochs: int, epoch_base: int) -> Tuple[np.ndarray, int]

The expected arguments are:

  • weights (Optional[np.ndarray]): Either a Numpy array containing the flattened weights of the global model or None. In the latter case the participant must properly initialize the weights instead of loading them.

  • epochs (int): The number of epochs to be trained during the Federated Learning round. Can be any non-negative number, including zero (in which case no local training is performed).

  • epoch_base (int): A global training epoch number (e.g. for epoch dependent learning rate schedules and metrics logging).

The expected return values are:

  • np.ndarray: The flattened weights of the local model, which results from training the global model for the given number of epochs on the local data.

  • int: The number of samples in the training dataset, which is used for the aggregation strategies.
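
Putting these requirements together, a minimal skeleton of a custom Participant looks as follows (a sketch only; the actual training logic is developed below):

import numpy as np
from typing import Optional, Tuple

from xain_sdk.participant import Participant as ABCParticipant

class Participant(ABCParticipant):
    def train_round(
        self, weights: Optional[np.ndarray], epochs: int, epoch_base: int
    ) -> Tuple[np.ndarray, int]:
        # load the global weights, train locally, then return the updated
        # weights and the number of training samples
        ...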

The Participant’s base class provides utility methods to set the weights of the local model according to a given flat weights vector, by

set_pytorch_weights(weights: np.ndarray, shapes: List[Tuple[int, ...]], model: torch.nn.Module) -> None

to get a flattened weights vector from the local model, by

get_pytorch_weights(model: torch.nn.Module) -> np.ndarray

as well as the original shapes of the weights of the local model, by

get_pytorch_shapes(model: torch.nn.Module) -> List[Tuple[int, ...]]
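
These utilities are typically used together: the shapes recorded once from the model allow a flattened weights vector to be restored later, e.g.

# a sketch of the round trip between a model and its flat weights vector
shapes = self.get_pytorch_shapes(model=self.model)
flat_weights = self.get_pytorch_weights(model=self.model)
self.set_pytorch_weights(weights=flat_weights, shapes=shapes, model=self.model)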

Also, metrics of the current training epoch can be sent to a time series database via the Coordinator by

update_metrics(epoch, epoch_base, MetricName=metric_value, ...)

for any number of metrics.
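
For example, a loss and an accuracy value computed in the current epoch could be reported as

# `epoch_loss` and `epoch_accuracy` are illustrative locally computed values
self.update_metrics(epoch, epoch_base, Loss=epoch_loss, Accuracy=epoch_accuracy)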

PyTorch Model

Following the tutorial in the PyTorch documentation, we prepare the class Net(nn.Module), which contains the model architecture and the definition of the forward pass together with the training loop. In addition, we include two methods that allow us to load weights into the model and to export them after the training rounds.

CNN Setup

We reuse the architecture from the PyTorch documentation for the CIFAR10 benchmark, relying on the following imports

import torch
from torch import Tensor, nn, optim, utils
from torch.nn import functional as F
from torchvision import datasets, transforms
from tqdm import tqdm

for the model definition, the data handling and the progress bars. We start with the preparation of our CNN architecture

def __init__(self) -> None:
    super(Net, self).__init__()
    # two convolutional layers with max pooling in between
    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    # three fully connected layers down to the 10 CIFAR10 classes
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

def forward(self, x: Tensor) -> Tensor:
    # convolution, activation and pooling, twice
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    # flatten the feature maps to one vector per sample
    x = x.view(-1, 16 * 5 * 5)
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.fc3(x)
    return x
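
As a quick sanity check (not part of the example code), a forward pass on a CIFAR10-sized batch should yield one score per class:

net = Net()
out = net(torch.randn(4, 3, 32, 32))  # a random batch of four 32x32 RGB images
assert out.shape == (4, 10)  # ten class scores per sample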

The last part of our model class sets up the loss function and the optimizer and defines the training loop.

def train_n_epochs(self, trainloader, number_of_epochs: int) -> None:
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

    for epoch in tqdm(range(number_of_epochs), desc="Epochs"):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in tqdm(enumerate(trainloader, 0), desc="Batches"):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = self(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # accumulate the loss so it can be monitored per epoch
            running_loss += loss.item()
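
As a standalone usage example (hypothetical, outside the federated setting, with trainloader constructed as in the next section):

net = Net()
net.train_n_epochs(trainloader, number_of_epochs=2)  # two passes over the local data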

As one might expect, this training loop is exactly what we need to call from the train_round() method of our Participant class.

PyTorch Data Loader and Participant Initialization

In this tutorial we use the standard data transformation and loading as described in the PyTorch documentation. The only difference is that we initialise each Participant with a randomly shuffled training set.
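
For completeness, the transform referenced below is the standard CIFAR10 normalisation from the PyTorch tutorial:

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)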

self.trainset = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
self.trainloader = utils.data.DataLoader(
    self.trainset, batch_size=4, shuffle=True, num_workers=2
)

For validation purposes we also create a test set

self.testset = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)
self.testloader = utils.data.DataLoader(
    self.testset, batch_size=4, shuffle=False, num_workers=2
)
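
The example itself does not evaluate on the test set, but a minimal accuracy check could look like the following sketch:

# evaluate the local model on the test set (illustrative, not part of the SDK)
correct, total = 0, 0
with torch.no_grad():
    for inputs, labels in self.testloader:
        predictions = self.model(inputs).argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
accuracy = correct / total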

Besides the test and train set, the Participant also contains a model

self.model = Net()

that is an instance of the CNN we defined following the PyTorch tutorial. The utility method for setting the model weights requires the original shapes of the weights, obtainable as

self.model(torch.zeros((4, 3, 32, 32)))
self.model_shapes: List[Tuple[int, ...]] = self.get_pytorch_shapes(model=self.model)

where the dummy forward pass makes sure that all weights are initialized before their shapes are read, but has no effect otherwise.
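
For the CNN defined above, assuming the shapes are recorded in module registration order, the list contains the weight and bias shapes of each layer:

# model_shapes for Net, as (weight shape, bias shape) per layer:
# [(6, 3, 5, 5), (6,),      conv1
#  (16, 6, 5, 5), (16,),    conv2
#  (120, 400), (120,),      fc1
#  (84, 120), (84,),        fc2
#  (10, 84), (10,)]         fc3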

PyTorch Training Round

The implementation of the actual train_round() method consists of three main steps. First, the provided weights of the global model are loaded into the local model, as

if weights is not None:
    self.set_pytorch_weights(weights=weights, shapes=self.model_shapes, model=self.model)

Next, the local model is trained for the given number of epochs on the local data (gathering the metrics in each epoch is still marked as a TODO), as

number_samples: int = len(self.trainloader.dataset)  # number of samples, not number of batches
# TODO: return metric values from `train_n_epochs`
self.model.train_n_epochs(self.trainloader, epochs)

Finally, the updated weights of the local model and the number of samples of the train dataset are returned, as

weights = self.get_pytorch_weights(model=self.model)
return weights, number_samples

If there are no weights provided, then the participant initializes new weights according to its model definition and returns them without further training, as

else:
    self.init_model()
    number_samples = 0
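
Putting these pieces together, the complete method might look like the following sketch, where init_model() is the example's own weight-initialisation helper:

def train_round(
    self, weights: Optional[np.ndarray], epochs: int, epoch_base: int
) -> Tuple[np.ndarray, int]:
    number_samples: int
    if weights is not None:
        # continue training from the global model
        self.set_pytorch_weights(weights=weights, shapes=self.model_shapes, model=self.model)
        number_samples = len(self.trainloader.dataset)
        # TODO: return metric values from `train_n_epochs`
        self.model.train_n_epochs(self.trainloader, epochs)
    else:
        # no global model yet: initialize fresh weights without training
        self.init_model()
        number_samples = 0
    weights = self.get_pytorch_weights(model=self.model)
    return weights, number_samples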