xain_sdk package

Provides the SDK for the xain package.

Submodules

xain_sdk.logger module

Logging configuration.

xain_sdk.logger.get_logger(name, level=None)

Wrap the Python logger with the default configuration of structlog.

Parameters
  • name (str) – Identification name. For module name pass name=__name__.

  • level (Optional[int]) – Threshold for this logger.

Return type

BoundLoggerLazyProxy

Returns

The wrapped Python logger with the default configuration of structlog.

xain_sdk.logger.set_log_level(level)

Set the log level on the root logger.

Since the root logger log level is inherited by all the loggers by default, this is like setting a default log level.

Parameters

level (Union[str, int]) – The log level, as documented in the Python standard library.

Return type

None
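The inheritance behaviour described for set_log_level() can be illustrated with the standard library alone. The following is a minimal sketch that uses only Python's logging module rather than xain_sdk itself; the logger name "my_app.training" is made up for illustration.

```python
import logging

# Set the threshold on the root logger, which is what set_log_level()
# is documented to do.
logging.getLogger().setLevel(logging.WARNING)

# A logger without its own level falls back to the level of its ancestors.
child = logging.getLogger("my_app.training")  # hypothetical logger name
effective_before = child.getEffectiveLevel()

# An explicit level on the child takes precedence over the root level.
child.setLevel(logging.DEBUG)
effective_after = child.getEffectiveLevel()
```

This is why a level set on the root logger acts as a default: it applies only to loggers that have not been given a level of their own.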

xain_sdk.participant module

Participant API

class xain_sdk.participant.InternalParticipant(participant, store=<xain_sdk.store.NullObjectStore object>)

Bases: object

Internal representation that encapsulates the user-defined Participant class.

Parameters
  • participant (Participant) – A user provided implementation of a participant.

  • store (AbstractStore) – A client for a storage service.

train_round(weights, epochs, epoch_base)

Wrap the user provided participant train_round() method.

The metrics gathered by the user are passed along as a JSON string.

Parameters
  • weights (Optional[ndarray]) – The weights of the model to be trained.

  • epochs (int) – The number of epochs to be trained.

  • epoch_base (int) – The global training epoch number.

Return type

Tuple[ndarray, int, str]

Returns

The updated model weights, the number of training samples and the metrics.
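As a rough illustration of the wrapping described above, the sketch below appends JSON-encoded metrics to the result of a user-supplied training function. It is not the SDK's implementation: wrap_train_round, fake_train_round and the shape of the metrics dictionary are assumptions, and plain lists stand in for ndarray weights.

```python
import json
from typing import Callable, Dict, List, Tuple

Weights = List[float]  # stand-in for ndarray in this sketch


def wrap_train_round(
    user_train_round: Callable[[Weights, int, int], Tuple[Weights, int]],
    metrics: Dict[str, List[float]],
) -> Callable[[Weights, int, int], Tuple[Weights, int, str]]:
    """Return a wrapper that adds the JSON-encoded metrics to the result."""

    def wrapped(weights: Weights, epochs: int, epoch_base: int) -> Tuple[Weights, int, str]:
        new_weights, number_samples = user_train_round(weights, epochs, epoch_base)
        # Serialize after training so metrics gathered during the round are included.
        return new_weights, number_samples, json.dumps(metrics)

    return wrapped


# Hypothetical user code: records one metric and shifts the weights.
metrics: Dict[str, List[float]] = {}


def fake_train_round(weights: Weights, epochs: int, epoch_base: int) -> Tuple[Weights, int]:
    metrics["loss"] = [0.5]
    return [w + 1.0 for w in weights], 10


result = wrap_train_round(fake_train_round, metrics)([0.0, 1.0], 1, 0)
```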

write_weights(round, weights)

A wrapper for the write_weights() method of the store.

Return type

None

class xain_sdk.participant.Participant

Bases: abc.ABC

An abstract participant for federated learning.

Parameters
  • metrics – A dictionary to gather the metrics of the current training round.

  • dummy_id – A fake id for the participant. Will be replaced later on.

static get_pytorch_shapes(model)

Get the shapes of the weights of a pytorch model.

Note

This only works with models that have already completed at least one forward pass.

Parameters

model (Module) – A pytorch model.

Returns

The shapes of the model weights per layer.

Return type

List[Tuple[int, ...]]

static get_pytorch_weights(model)

Get the flattened weights vector from a pytorch model.

Note

This only works with models that have already completed at least one forward pass.

Parameters

model (Module) – A pytorch model.

Returns

The vector of the flattened model weights.

Return type

ndarray

static get_tensorflow_shapes(model)

Get the shapes of the weights of a tensorflow model.

Parameters

model (Model) – A tensorflow model.

Returns

The shapes of the model weights per layer.

Return type

List[Tuple[int, ...]]

static get_tensorflow_weights(model)

Get the flattened weights vector from a tensorflow model.

Parameters

model (Model) – A tensorflow model.

Returns

The vector of the flattened model weights.

Return type

ndarray

static set_pytorch_weights(weights, shapes, model)

Set the weights of a pytorch model.

Parameters
  • weights (ndarray) – A vector of flat model weights.

  • shapes (List[Tuple[int, ...]]) – The original shapes of the pytorch model weights.

  • model (Module) – A pytorch model.

Return type

None

static set_tensorflow_weights(weights, shapes, model)

Set the weights of a tensorflow model.

Parameters
  • weights (ndarray) – A vector of flat model weights.

  • shapes (List[Tuple[int, ...]]) – The original shapes of the tensorflow model weights.

  • model (Model) – A tensorflow model.

Return type

None
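The get_*_shapes(), get_*_weights() and set_*_weights() helpers revolve around one idea: per-layer weights are flattened into a single vector, and the recorded shapes allow the layers to be rebuilt later. The sketch below shows that round trip with NumPy only. The functions flatten and restore are hypothetical names, not SDK functions, and the SDK's framework-specific helpers may differ in detail.

```python
from typing import List, Tuple

import numpy as np


def flatten(layers: List[np.ndarray]) -> Tuple[np.ndarray, List[Tuple[int, ...]]]:
    """Concatenate all layers into one flat vector and record their shapes."""
    shapes = [layer.shape for layer in layers]
    flat = np.concatenate([layer.ravel() for layer in layers])
    return flat, shapes


def restore(flat: np.ndarray, shapes: List[Tuple[int, ...]]) -> List[np.ndarray]:
    """Split a flat vector back into layers with the given shapes."""
    sizes = [int(np.prod(shape)) for shape in shapes]
    # Split points are the cumulative sizes, excluding the final total.
    split_points = np.cumsum(sizes)[:-1]
    chunks = np.split(flat, split_points)
    return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]


# Two made-up layers: a 2x3 kernel and a bias vector of length 3.
layers = [np.arange(6, dtype=np.float32).reshape(2, 3), np.ones(3, dtype=np.float32)]
flat, shapes = flatten(layers)
restored = restore(flat, shapes)
```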

abstract train_round(weights, epochs, epoch_base)

Train a model in a federated learning round.

A model is given in terms of its weights and the model is trained on the participant’s dataset for a number of epochs. The weights of the updated model are returned in combination with the number of samples of the train dataset.

Any metrics that should be returned to the coordinator must be gathered via the participant’s update_metrics() utility method per epoch.

If the weights given are None, then the participant is expected to initialize the weights according to its model definition and return them without training.

Parameters
  • weights (Optional[ndarray]) – The weights of the model to be trained.

  • epochs (int) – The number of epochs to be trained.

  • epoch_base (int) – The global training epoch number.

Return type

Tuple[ndarray, int]

Returns

The updated model weights and the number of training samples.
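A participant implementation might look roughly like the sketch below. To keep the example runnable without the SDK, ParticipantBase is a hypothetical stand-in for xain_sdk.participant.Participant, and the one-parameter linear model, the synthetic dataset and the learning rate are all invented for illustration.

```python
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import numpy as np


class ParticipantBase(ABC):
    """Hypothetical stand-in for xain_sdk.participant.Participant."""

    @abstractmethod
    def train_round(
        self, weights: Optional[np.ndarray], epochs: int, epoch_base: int
    ) -> Tuple[np.ndarray, int]:
        ...


class LinearParticipant(ParticipantBase):
    """Fits y = w * x with plain gradient descent on a tiny synthetic dataset."""

    def __init__(self) -> None:
        self.x = np.array([1.0, 2.0, 3.0, 4.0])
        self.y = 2.0 * self.x

    def train_round(
        self, weights: Optional[np.ndarray], epochs: int, epoch_base: int
    ) -> Tuple[np.ndarray, int]:
        if weights is None:
            # Initialization round: return fresh weights without training.
            # The sample count of 0 is a choice made for this sketch.
            return np.zeros(1), 0
        w = weights.copy()
        for _ in range(epochs):
            # Gradient of the mean squared error with respect to w.
            grad = np.mean(2.0 * (w[0] * self.x - self.y) * self.x)
            w[0] -= 0.01 * grad
        return w, len(self.x)


participant = LinearParticipant()
initial, _ = participant.train_round(None, 0, 0)
trained, number_samples = participant.train_round(initial, 50, 0)
```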

update_metrics(epoch, epoch_base, **kwargs)

Update the metrics for the current training epoch.

Metrics are expected as key=value pairs where the key is a name for the metric and the value is any value of the metric from the current epoch. Values can be scalars/lists/arrays of numerical values and must be convertible to floats. If a metric is already present for the current epoch, then its values will be overwritten.

Examples

update_metrics(0, 0, Accuracy=0.8, Accuracy_per_Category=[0.8, 0.7, 0.9])
update_metrics(0, 0, F1_per_Category=np.array([0.85, 0.9, 0.95]))

Parameters
  • epoch (int) – The local training epoch number.

  • epoch_base (int) – The global training epoch number.

  • kwargs (Any) – The metrics names and values.

Return type

None
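The documented semantics (float conversion, scalar or sequence values, overwrite on repeated updates for the same epoch) can be sketched as follows. MetricsRecorder is a hypothetical class, and both the nested dictionary layout and the use of epoch_base + epoch as the key are assumptions; the SDK's internal representation may differ.

```python
from typing import Any, Dict, List


class MetricsRecorder:
    """Hypothetical illustration of the documented update_metrics() behaviour."""

    def __init__(self) -> None:
        # metric name -> global epoch number -> list of float values
        self.metrics: Dict[str, Dict[int, List[float]]] = {}

    def update_metrics(self, epoch: int, epoch_base: int, **kwargs: Any) -> None:
        global_epoch = epoch_base + epoch  # assumed keying scheme
        for name, value in kwargs.items():
            try:
                # Lists, tuples and arrays: convert element by element.
                values = [float(v) for v in value]
            except TypeError:
                # Scalars are not iterable: store them as a one-element list.
                values = [float(value)]
            # Assigning to the same epoch key overwrites earlier values.
            self.metrics.setdefault(name, {})[global_epoch] = values


recorder = MetricsRecorder()
recorder.update_metrics(0, 0, Accuracy=0.8, Accuracy_per_Category=[0.8, 0.7, 0.9])
recorder.update_metrics(0, 0, Accuracy=0.9)  # replaces the earlier Accuracy value
```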

xain_sdk.participant_state_machine module

Module implementing the networked Participant using gRPC.

class xain_sdk.participant_state_machine.ParState

Bases: enum.Enum

Enumeration of Participant states.

DONE = 3
TRAINING = 2
WAITING = 1

class xain_sdk.participant_state_machine.StateRecord(state=<ParState.WAITING: 1>, round=-1)

Bases: object

Thread-safe record of a participant’s state and round number.

lookup()

Get the state and round number.

Return type

Tuple[ParState, int]

Returns

The state and round number.

update(state)

Update the state.

Parameters

state (ParState) – The state to update to.

Return type

None

wait_until_selected_or_done()

Wait until the participant is selected for training or is done.

Return type

ParState

Returns

The new state the participant is in.
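One common way to build a thread-safe record with a blocking wait is a condition variable. The sketch below follows the method names documented above but is not the SDK's implementation: StateRecordSketch is a hypothetical class, and ParState is redefined locally with the documented values so that the example runs on its own.

```python
import threading
from enum import Enum
from typing import Tuple


class ParState(Enum):
    WAITING = 1
    TRAINING = 2
    DONE = 3


class StateRecordSketch:
    """Hypothetical thread-safe record of a state and a round number."""

    def __init__(self, state: ParState = ParState.WAITING, round: int = -1) -> None:
        self._cond = threading.Condition()
        self._state = state
        self._round = round

    def lookup(self) -> Tuple[ParState, int]:
        with self._cond:
            return self._state, self._round

    def update(self, state: ParState) -> None:
        with self._cond:
            self._state = state
            # Wake up any thread blocked in wait_until_selected_or_done().
            self._cond.notify_all()

    def wait_until_selected_or_done(self) -> ParState:
        with self._cond:
            self._cond.wait_for(
                lambda: self._state in (ParState.TRAINING, ParState.DONE)
            )
            return self._state


record = StateRecordSketch()
# Another thread flips the state shortly after the main thread starts waiting.
threading.Timer(0.05, record.update, args=(ParState.TRAINING,)).start()
new_state = record.wait_until_selected_or_done()
```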

xain_sdk.participant_state_machine.begin_training(state_record, channel, participant)

Perform actions in the Participant state TRAINING.

Parameters
  • state_record (StateRecord) – The participant’s state record.

  • channel (Channel) – A gRPC channel to the coordinator.

  • participant (InternalParticipant) – The participant for local training.

Return type

None

xain_sdk.participant_state_machine.begin_waiting(state_record, channel, participant)

Perform actions in the Participant state WAITING.

Parameters
  • state_record (StateRecord) – The participant’s state record.

  • channel (Channel) – A gRPC channel to the coordinator.

  • participant (InternalParticipant) – The participant for local training.

Return type

None

xain_sdk.participant_state_machine.end_training_round(channel, weights, number_samples, metrics)

Start a training round completion exchange with a coordinator.

The locally trained model weights, the number of samples and the gathered metrics are sent.

Parameters
  • channel (Channel) – A gRPC channel to the coordinator.

  • weights (ndarray) – The weights of the locally trained model.

  • number_samples (int) – The number of samples in the training dataset.

  • metrics (str) – Metrics data.

Return type

None

xain_sdk.participant_state_machine.message_loop(channel, state_record, terminate)

Periodically send (and handle) heartbeat messages in a loop.

Parameters
  • channel (Channel) – A gRPC channel to the coordinator.

  • state_record (StateRecord) – The participant’s state record.

  • terminate (Event) – An event to terminate the message loop.

Return type

None
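A periodic loop that can be stopped through a threading.Event is often written with Event.wait() as the sleep. The sketch below shows that pattern only. message_loop_sketch, fake_heartbeat and the interval parameter are hypothetical, and no gRPC channel or state record is involved.

```python
import threading
from typing import Callable, List


def message_loop_sketch(
    send_heartbeat: Callable[[], None], terminate: threading.Event, interval: float
) -> None:
    """Call send_heartbeat every `interval` seconds until `terminate` is set."""
    # Event.wait() returns True as soon as the event is set, which ends the loop
    # without waiting for the rest of the interval.
    while not terminate.wait(timeout=interval):
        send_heartbeat()


beats: List[int] = []
terminate = threading.Event()


def fake_heartbeat() -> None:
    # Stand-in for a real heartbeat exchange; stops the loop after three beats.
    beats.append(len(beats))
    if len(beats) == 3:
        terminate.set()


worker = threading.Thread(
    target=message_loop_sketch, args=(fake_heartbeat, terminate, 0.01)
)
worker.start()
worker.join(timeout=5)
```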

xain_sdk.participant_state_machine.rendezvous(channel)

Start a rendezvous exchange with a coordinator.

Parameters

channel (Channel) – A gRPC channel to the coordinator.

Return type

None

xain_sdk.participant_state_machine.start_participant(participant, config)

Top-level function for the participant’s state machine.

After rendezvous and heartbeat initiation, the Participant is WAITING. If selected to train for the current round, it moves to TRAINING, otherwise it remains in WAITING. After training is complete for the round, it moves back to WAITING. When there is no more training to be done, it moves to the terminal state DONE.

Parameters
  • participant (Participant) – The participant for local training.

  • config (Config) – A valid config.

Return type

None
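The transitions described above can be pictured as a small state machine. In the sketch below the signal names ("selected", "round_complete", "finished", "not_selected") are invented for illustration and do not correspond to actual coordinator messages; allowing "finished" from any state is likewise an assumption.

```python
from enum import Enum
from typing import Iterable, List


class ParState(Enum):
    WAITING = 1
    TRAINING = 2
    DONE = 3


def run_state_machine(signals: Iterable[str]) -> List[ParState]:
    """Replay hypothetical signals and return the sequence of visited states."""
    state = ParState.WAITING
    visited = [state]
    for signal in signals:
        if state is ParState.WAITING and signal == "selected":
            state = ParState.TRAINING
        elif state is ParState.TRAINING and signal == "round_complete":
            state = ParState.WAITING
        elif signal == "finished":
            state = ParState.DONE
        # Any other signal leaves the state unchanged.
        if state is not visited[-1]:
            visited.append(state)
        if state is ParState.DONE:
            break  # DONE is terminal
    return visited


trace = run_state_machine(["not_selected", "selected", "round_complete", "finished"])
```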

xain_sdk.participant_state_machine.start_training_round(channel)

Start a training round initiation exchange with a coordinator.

The decoded contents of the response from the coordinator are returned.

Parameters

channel (Channel) – A gRPC channel to the coordinator.

Return type

Tuple[ndarray, int, int]

Returns

A tuple (weights, epochs, epoch_base) where weights is the weights of a global model to train on, epochs is the number of epochs to train, and epoch_base is the epoch base of the global model.

xain_sdk.participant_state_machine.training_round(channel, participant, round)

Initiate a training round exchange with a coordinator.

Begins with start_training_round. Then performs local training computation using the participant. Finally, completes with end_training_round.

If the coordinator sends empty weights (i.e. a 0th round for weights initialization), the aggregation metadata and metrics from the participant are ignored.

Parameters
  • channel (Channel) – A gRPC channel to the coordinator.

  • participant (InternalParticipant) – The local participant.

  • round (int) – The round number.

Raises
  • TypeError – If the model weights received from the participant’s local training round are not of type ndarray.

  • TypeError – If the aggregation metadata received from the participant’s local training round is not of type int.

  • ValueError – If the aggregation metadata received from the participant’s local training round is negative.

  • TypeError – If the metrics received from the participant’s local training round are not of type str.

Return type

None
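The checks listed under Raises could be sketched as below. validate_training_result is a hypothetical helper, not an SDK function, and it assumes that the aggregation metadata is the number of training samples returned by the participant.

```python
import numpy as np


def validate_training_result(weights, number_samples, metrics) -> None:
    """Raise if the results of a local training round have unexpected types."""
    if not isinstance(weights, np.ndarray):
        raise TypeError("Model weights must be of type ndarray.")
    if not isinstance(number_samples, int):
        raise TypeError("Aggregation metadata must be of type int.")
    if number_samples < 0:
        raise ValueError("Aggregation metadata must not be negative.")
    if not isinstance(metrics, str):
        raise TypeError("Metrics must be of type str.")


# A well-formed result passes silently.
validate_training_result(np.zeros(2), 4, "{}")
```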

xain_sdk.participant_state_machine.transit(state_record, heartbeat_response)

Participant state transition function on a heartbeat response.

Updates the state record.

Parameters
  • state_record (StateRecord) – The updatable state record of the participant.

  • heartbeat_response (HeartbeatResponse) – The heartbeat response from the coordinator.

Return type

None

xain_sdk.store module

This module provides classes for weights storage.

It currently only works with services that provide the AWS S3 API.

class xain_sdk.store.AbstractStore

Bases: abc.ABC

An abstract class that defines the API a store must implement.

abstract write_weights(round, weights)

Store the given weights, corresponding to the given round.

Parameters
  • round (int) – The round number the weights correspond to.

  • weights (ndarray) – The weights to store.

Return type

None

class xain_sdk.store.NullObjectStore

Bases: xain_sdk.store.AbstractStore

A dummy store that does not do anything.

write_weights(round, weights)

Return without doing anything.

Parameters
  • round (int) – The round number the weights correspond to (unused).

  • weights (List[ndarray]) – The weights to store (unused).

Return type

None
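A custom store only needs to implement write_weights(). The sketch below keeps weights in memory, which can be handy in tests. AbstractStoreSketch is a hypothetical stand-in for xain_sdk.store.AbstractStore so that the example runs without the SDK, InMemoryStore is an invented name, and plain lists stand in for ndarray weights.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class AbstractStoreSketch(ABC):
    """Hypothetical stand-in for xain_sdk.store.AbstractStore."""

    @abstractmethod
    def write_weights(self, round: int, weights: List[float]) -> None:
        ...


class InMemoryStore(AbstractStoreSketch):
    """Keeps the weights of each round in a dictionary."""

    def __init__(self) -> None:
        self.saved: Dict[int, List[float]] = {}

    def write_weights(self, round: int, weights: List[float]) -> None:
        # Copy the weights so later changes by the caller do not alter the store.
        self.saved[round] = list(weights)


store = InMemoryStore()
store.write_weights(1, [0.1, 0.2])
```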

class xain_sdk.store.S3Store(config)

Bases: xain_sdk.store.AbstractStore

A store for services that offer the AWS S3 API.

Parameters
  • config (StorageConfig) – The storage configuration (endpoint URL, credentials, etc.).

  • s3 – The S3 bucket.

write_weights(round, weights)

Store the given weights, corresponding to the given round.

Parameters
  • round (int) – The round number the weights correspond to.

  • weights (ndarray) – The weights to store.

Return type

None