TensorFlow Keras Participant Example

This is an example of a TensorFlow Keras implementation of a Participant for federated learning.

We cover the requirements of the Participant abstract base class, give ideas on how to handle a TF Keras model and TF Keras data in the Participant, and show how to implement a federated learning TF Keras training round. You can find the complete source code here. The example code uses type hints to be precise about the expected data types.

Participant Abstract Base Class

The SDK provides an abstract base class for Participants, which can be imported as

from xain_sdk.participant import Participant as ABCParticipant

A custom Participant should inherit from the abstract base class, like

class Participant(ABCParticipant):

and must implement the train_round() method to execute a round of federated learning, where each round consists of a given number of training epochs. The method has the signature

train_round(self, weights: Optional[np.ndarray], epochs: int, epoch_base: int) -> Tuple[np.ndarray, int]
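To make the contract concrete without depending on the SDK, the following sketch uses Python's abc module with a hypothetical stand-in base class, SketchParticipantBase, that only mirrors the train_round() signature above; the real xain_sdk base class provides more functionality. It shows that a subclass without train_round() cannot be instantiated, while a trivial (hypothetical) EchoParticipant that returns the global weights unchanged satisfies the contract.

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import numpy as np


class SketchParticipantBase(ABC):
    """Hypothetical stand-in for xain_sdk's Participant base class (illustration only)."""

    @abstractmethod
    def train_round(
        self, weights: Optional[np.ndarray], epochs: int, epoch_base: int
    ) -> Tuple[np.ndarray, int]:
        """Train the local model for one federated learning round."""


class IncompleteParticipant(SketchParticipantBase):
    pass  # train_round() is missing, so instantiation raises TypeError


class EchoParticipant(SketchParticipantBase):
    """Hypothetical participant that performs no training and echoes the weights back."""

    def train_round(
        self, weights: Optional[np.ndarray], epochs: int, epoch_base: int
    ) -> Tuple[np.ndarray, int]:
        if weights is None:
            weights = np.zeros(shape=(4,), dtype=np.float32)  # placeholder initialization
        return weights, 0  # no local samples were used


try:
    IncompleteParticipant()
except TypeError as error:
    print(f"cannot instantiate: {error}")

print(EchoParticipant().train_round(weights=None, epochs=0, epoch_base=0))
```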

The expected arguments are:

  • weights (Optional[np.ndarray]): Either a Numpy array containing the flattened weights of the global model or None. In the latter case the participant must properly initialize the weights instead of loading them.

  • epochs (int): The number of epochs to be trained during the federated learning round. It can be any non-negative integer, including zero.

  • epoch_base (int): A global training epoch number (e.g. for epoch dependent learning rate schedules and metrics logging).

The expected return values are:

  • np.ndarray: The flattened weights of the local model, obtained by training the global model for the given number of epochs on the local data.

  • int: The number of samples in the training dataset, which is used by aggregation strategies.
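As an illustration of how the returned sample count might be used, here is a sketch of sample-weighted averaging in the style of FedAvg. This is one possible aggregation strategy, not necessarily the one the coordinator uses, and the helper name weighted_average is hypothetical.

```python
from typing import List, Tuple

import numpy as np


def weighted_average(results: List[Tuple[np.ndarray, int]]) -> np.ndarray:
    """Average flattened weight vectors, weighting each by its number of samples.

    Hypothetical helper: each entry is the (weights, number_samples) tuple
    returned by one participant's train_round().
    """
    total_samples = sum(number_samples for _, number_samples in results)
    if total_samples == 0:
        raise ValueError("at least one participant must report a positive sample count")
    stacked = np.stack([weights for weights, _ in results])  # shape (participants, size)
    factors = np.array([n for _, n in results], dtype=np.float64) / total_samples
    return factors @ stacked  # sum over k of (n_k / n) * w_k


global_weights = weighted_average([
    (np.array([1.0, 1.0]), 1),
    (np.array([3.0, 3.0]), 3),
    (np.array([9.0, 9.0]), 0),  # a participant reporting 0 samples contributes nothing
])
print(global_weights)  # [2.5 2.5]
```

Under this strategy, a participant that reports zero samples has no influence on the aggregate.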

The Participant’s base class provides utility methods to get a flattened weights vector from the local model, by

get_tensorflow_weights(model: tf.keras.Model) -> np.ndarray

and to set the weights of the local model according to a given flattened weights vector, by

set_tensorflow_weights(weights: np.ndarray, shapes: List[Tuple[int, ...]], model: tf.keras.Model) -> None

as well as to get the original shapes of the weights of the local model, by

get_tensorflow_shapes(model: tf.keras.Model) -> List[Tuple[int, ...]]
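The SDK's implementation of these utilities is not shown here. Assuming they concatenate the per-layer arrays in order and split them back by shape, the round trip can be sketched in NumPy. flatten_weights and unflatten_weights are hypothetical names; the shapes are those of the kernels and biases of the example model defined in the TF Keras Model section.

```python
from typing import List, Tuple

import numpy as np


def flatten_weights(arrays: List[np.ndarray]) -> np.ndarray:
    """Concatenate per-layer weight arrays into one flat vector (hypothetical helper)."""
    return np.concatenate([array.ravel() for array in arrays])


def unflatten_weights(flat: np.ndarray, shapes: List[Tuple[int, ...]]) -> List[np.ndarray]:
    """Split a flat vector back into arrays of the given shapes (hypothetical helper)."""
    sizes = [int(np.prod(shape)) for shape in shapes]
    if flat.size != sum(sizes):
        raise ValueError(f"expected {sum(sizes)} values, got {flat.size}")
    boundaries = np.cumsum(sizes)[:-1]  # split points between consecutive arrays
    return [part.reshape(shape) for part, shape in zip(np.split(flat, boundaries), shapes)]


# Shapes of [kernel, bias, kernel, bias] for the 10 -> 6 -> 2 dense model.
shapes = [(10, 6), (6,), (6, 2), (2,)]
arrays = [np.random.random(size=shape).astype(np.float32) for shape in shapes]
flat = flatten_weights(arrays)
restored = unflatten_weights(flat, shapes)
print(flat.shape)  # (80,)
```

With a Keras model, the list of arrays would come from model.get_weights() and could be written back with model.set_weights().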

Also, metrics of the current training epoch can be sent to a time-series database via the coordinator, by

update_metrics(epoch, epoch_base, MetricName=metric_value, ...)

passing any number of metrics as keyword arguments.
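The coordinator-side handling of these metrics is not part of this example. As a sketch of the calling pattern only, the following hypothetical MetricsRecorder stores keyword metrics per epoch instead of sending them. It assumes the global epoch number is epoch_base + epoch; that is an assumption about how epoch_base is meant to be combined, not documented SDK behavior.

```python
from typing import Dict, List, Tuple


class MetricsRecorder:
    """Hypothetical stand-in for update_metrics(): stores metrics instead of sending them."""

    def __init__(self) -> None:
        self.records: List[Tuple[int, Dict[str, float]]] = []

    def update_metrics(self, epoch: int, epoch_base: int, **metrics: float) -> None:
        # Assumption: the global epoch is the round's base epoch plus the local epoch index.
        global_epoch = epoch_base + epoch
        self.records.append((global_epoch, {name: float(value) for name, value in metrics.items()}))


recorder = MetricsRecorder()
for epoch in range(2):  # a round of two epochs starting at global epoch 10
    recorder.update_metrics(epoch, 10, Loss=0.5 / (epoch + 1), Accuracy=0.6 + 0.1 * epoch)
print([global_epoch for global_epoch, _ in recorder.records])  # [10, 11]
```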

TF Keras Model

A TF Keras model definition might either be loaded from a file, generated during the initialization of the Participant, or even generated on the fly in train_round(). Here, we present a simple dense neural network for classification, generated during the Participant’s initialization. The example model has an input layer with 10 features per sample, as

input_layer: Tensor = Input(shape=(10,), dtype="float32")

Next, it has a fully connected hidden layer with 6 ReLU-activated units, as

hidden_layer: Tensor = Dense(
    units=6,
    activation="relu",
    use_bias=True,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
)(inputs=input_layer)

Finally, it has a fully connected output layer with 2 softmax-activated units, as

output_layer: Tensor = Dense(
    units=2,
    activation="softmax",
    use_bias=True,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
)(inputs=hidden_layer)
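To show what the two Dense layers compute, here is a NumPy sketch of the same 10 → 6 → 2 forward pass (a re-implementation for illustration, not Keras itself): Glorot-uniform kernels drawn from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)), zero biases, ReLU in the hidden layer, and softmax in the output layer. The helper names are hypothetical.

```python
import numpy as np


def glorot_uniform(fan_in: int, fan_out: int, rng: np.random.Generator) -> np.ndarray:
    """Sample a kernel from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(low=-limit, high=limit, size=(fan_in, fan_out))


def softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


rng = np.random.default_rng(seed=0)
w_hidden, b_hidden = glorot_uniform(10, 6, rng), np.zeros(6)  # hidden layer: 6 units
w_out, b_out = glorot_uniform(6, 2, rng), np.zeros(2)         # output layer: 2 units

x = rng.random(size=(4, 10))                       # batch of 4 samples with 10 features
hidden = np.maximum(0.0, x @ w_hidden + b_hidden)  # ReLU activation
probabilities = softmax(hidden @ w_out + b_out)    # class probabilities per sample
print(probabilities.shape)  # (4, 2)
```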

The model is compiled with the Adam optimizer, the categorical cross-entropy loss function, and the categorical accuracy metric, as

self.model: Model = Model(inputs=[input_layer], outputs=[output_layer])
self.model.compile(optimizer="Adam", loss="categorical_crossentropy", metrics=["categorical_accuracy"])
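For reference, the compiled loss and metric can be computed by hand. The following NumPy sketch evaluates categorical cross-entropy (the mean over samples of -Σ y·log p) and categorical accuracy (the fraction of samples whose predicted argmax matches the label's argmax) for one-hot labels. The clipping epsilon of 1e-7 mirrors Keras' default backend epsilon but is otherwise an assumption of this sketch.

```python
import numpy as np


def categorical_crossentropy(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-7) -> float:
    """Mean over samples of -sum_c y_true[c] * log(y_pred[c])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))


def categorical_accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of samples whose most probable class matches the one-hot label."""
    return float(np.mean(np.argmax(y_true, axis=1) == np.argmax(y_pred, axis=1)))


y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
y_pred = np.array([[0.8, 0.2], [0.4, 0.6]])
print(round(categorical_crossentropy(y_true, y_pred), 4))  # 0.367
print(categorical_accuracy(y_true, y_pred))                # 1.0
```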

The utility method for setting the model weights requires the original shapes of the weights, obtainable as

self.model_shapes: List[Tuple[int, ...]] = self.get_tensorflow_shapes(model=self.model)
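For the model above, these shapes are those of the two kernels and the two biases. The following plain-Python sketch, with the expected shapes written out by hand instead of queried from Keras, shows that they account for 80 values, which is the length of the flattened weights vector exchanged in every round.

```python
from math import prod
from typing import List, Tuple

# Expected shapes for Dense(6) on 10 inputs and Dense(2) on 6 inputs: [kernel, bias, kernel, bias].
model_shapes: List[Tuple[int, ...]] = [(10, 6), (6,), (6, 2), (2,)]

sizes = [prod(shape) for shape in model_shapes]
print(sizes)       # [60, 6, 12, 2]
print(sum(sizes))  # 80
```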

TF Keras Data

The data on which the model is trained can either be loaded from a data source (e.g. a file, bucket, or database) during the initialization of the Participant or on the fly in train_round(). Here, we employ randomly generated placeholder data as an example. This is by no means a meaningful dataset, but it is sufficient to convey the overall idea. The training dataset is shuffled and batched, as

rng = np.random.default_rng(seed=0)  # random placeholder data; one-hot labels match the model's 2 output units
self.trainset: Dataset = Dataset.from_tensor_slices(
    tensors=(rng.random(size=(80, 10), dtype=np.float32), np.eye(2, dtype=np.float32)[rng.integers(0, 2, size=80)])
).shuffle(buffer_size=80).batch(batch_size=10)

while the datasets for validation and testing are only batched, as

self.valset: Dataset = Dataset.from_tensor_slices(
    tensors=(rng.random(size=(10, 10), dtype=np.float32), np.eye(2, dtype=np.float32)[rng.integers(0, 2, size=10)])
).batch(batch_size=10)
self.testset: Dataset = Dataset.from_tensor_slices(
    tensors=(rng.random(size=(10, 10), dtype=np.float32), np.eye(2, dtype=np.float32)[rng.integers(0, 2, size=10)])
).batch(batch_size=10)
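Since the shuffle buffer size of 80 equals the number of training samples, shuffle() can draw from the whole dataset, i.e. it performs a full uniform shuffle, and batch(batch_size=10) then yields 8 batches of 10 samples. The following NumPy sketch reproduces that behavior with a hypothetical shuffled_batches helper rather than tf.data, and also shows how the sample count returned by train_round() could be derived from the batches instead of being hard-coded.

```python
from typing import Iterator, Tuple

import numpy as np


def shuffled_batches(
    features: np.ndarray, labels: np.ndarray, batch_size: int, rng: np.random.Generator
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield (features, labels) batches in a fresh random order (hypothetical helper).

    As with Dataset.batch() and its default drop_remainder=False, the last batch may be smaller.
    """
    order = rng.permutation(len(features))
    for start in range(0, len(features), batch_size):
        indices = order[start:start + batch_size]
        yield features[indices], labels[indices]


rng = np.random.default_rng(seed=0)
features = rng.random(size=(80, 10), dtype=np.float32)
labels = np.eye(2, dtype=np.float32)[rng.integers(0, 2, size=80)]  # one-hot, 2 classes

batches = list(shuffled_batches(features, labels, batch_size=10, rng=rng))
number_samples = sum(len(batch_features) for batch_features, _ in batches)
print(len(batches), number_samples)  # 8 80
```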

TF Keras Training Round

The implementation of the actual train_round() method consists of three main steps. First, the provided weights of the global model are loaded into the local model, as

if weights is not None:
    self.set_tensorflow_weights(weights=weights, shapes=self.model_shapes, model=self.model)

Next, the local model is trained for the given number of epochs on the local data, and the metrics are gathered after each epoch, as

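number_samples: int = 80  # size of the training dataset defined above
for epoch in range(epochs):
    # one pass over the training data; fit() ignores shuffle for tf.data input, the Dataset shuffles itself
    self.model.fit(x=self.trainset, verbose=2, shuffle=False)
    # evaluate() returns a list of scalars: [loss, categorical_accuracy]
    metrics: List[float] = self.model.evaluate(x=self.valset, verbose=0)
    self.update_metrics(epoch, epoch_base, Loss=metrics[0], Accuracy=metrics[1])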

Finally, the updated weights of the local model and the number of samples in the training dataset are returned, as

weights = self.get_tensorflow_weights(model=self.model)
return weights, number_samples

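If no weights are provided, the Participant instead initializes new weights according to its model definition and returns them without training. This is the else branch of the weights check: the training step belongs to the if branch, while the final return is shared by both branches, as

else:
    # no global weights yet: initialize the local model from its definition
    self.init_model()
    number_samples = 0  # no local training took place in this round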
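Putting the pieces together, the following self-contained sketch shows the control flow of train_round(), with the training loop in the if branch and a return shared by both branches. To stay runnable without TensorFlow or the SDK, it swaps the Keras model for a hypothetical NumPy softmax-regression model (a single dense layer, 10 → 2) trained with full-batch gradient descent, and replaces update_metrics() with a list. init_model(), the learning rate, the epoch_base + epoch numbering, and all names here are illustrative assumptions.

```python
from typing import List, Optional, Tuple

import numpy as np

LEARNING_RATE = 0.1  # hypothetical step size for full-batch gradient descent


class NumpySketchParticipant:
    """Hypothetical participant: softmax regression (10 features -> 2 classes) instead of Keras."""

    def __init__(self, seed: int = 0) -> None:
        self.rng = np.random.default_rng(seed)
        self.x = self.rng.random(size=(80, 10))               # placeholder features
        self.y = np.eye(2)[self.rng.integers(0, 2, size=80)]  # placeholder one-hot labels
        self.shapes: List[Tuple[int, ...]] = [(10, 2), (2,)]  # kernel and bias
        self.metrics: List[Tuple[int, float]] = []             # stand-in for update_metrics()
        self.init_model()

    def init_model(self) -> None:
        """Initialize the local weights from the model definition (hypothetical)."""
        self.kernel = self.rng.uniform(-0.5, 0.5, size=self.shapes[0])
        self.bias = np.zeros(self.shapes[1])

    def get_weights(self) -> np.ndarray:
        return np.concatenate([self.kernel.ravel(), self.bias])

    def set_weights(self, weights: np.ndarray) -> None:
        # first 10 * 2 values are the kernel; astype() copies, so the caller's array is never modified
        self.kernel = weights[:20].astype(np.float64).reshape(self.shapes[0])
        self.bias = weights[20:].astype(np.float64)

    def predict(self, x: np.ndarray) -> np.ndarray:
        logits = x @ self.kernel + self.bias
        exp = np.exp(logits - logits.max(axis=1, keepdims=True))
        return exp / exp.sum(axis=1, keepdims=True)

    def loss(self) -> float:
        probabilities = np.clip(self.predict(self.x), 1e-7, 1.0)
        return float(np.mean(-np.sum(self.y * np.log(probabilities), axis=1)))

    def train_round(
        self, weights: Optional[np.ndarray], epochs: int, epoch_base: int
    ) -> Tuple[np.ndarray, int]:
        if weights is not None:
            self.set_weights(weights)  # load the global model into the local model
            number_samples = len(self.x)
            for epoch in range(epochs):
                # gradient of the mean cross-entropy w.r.t. the logits is (p - y) / n
                grad_logits = (self.predict(self.x) - self.y) / len(self.x)
                self.kernel -= LEARNING_RATE * self.x.T @ grad_logits
                self.bias -= LEARNING_RATE * grad_logits.sum(axis=0)
                # assumed global epoch numbering: epoch_base + epoch
                self.metrics.append((epoch_base + epoch, self.loss()))
        else:
            self.init_model()  # no global weights yet: initialize locally
            number_samples = 0
        return self.get_weights(), number_samples


participant = NumpySketchParticipant()
initial_weights, count = participant.train_round(weights=None, epochs=1, epoch_base=0)
print(initial_weights.shape, count)  # (22,) 0
```

A first call without weights returns freshly initialized weights and a sample count of 0; later calls with global weights train locally and report all 80 samples.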