input: data + tff conversion function from CustomClientData

output: TFF model for predicting customer paths

description:

Simulating federated learning for predicting customer paths.

Import relevant modules

# uncomment this cell to get the newest version of TFF

# tensorflow_federated_nightly also brings in tf_nightly, which
# can cause a duplicate tensorboard install, leading to errors.
#!pip uninstall --yes tensorboard tb-nightly

#!pip install --quiet --upgrade tensorflow-federated-nightly
#!pip install --quiet --upgrade nest-asyncio
#!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both
import collections
import matplotlib.pyplot as plt
import numpy as np
import nest_asyncio

nest_asyncio.apply()

from pathlib import Path
from pyarrow import feather
import pandas as pd

import tensorflow as tf
import tensorflow_federated as tff


from ml_federated_customer_journey.customclientdata import (
    create_tff_client_data_from_df,
)
train_test_client_split = tff.simulation.datasets.ClientData.train_test_client_split

You can also view the results using TensorBoard. The same cell runs a trivial federated computation to check that TFF is working:

%load_ext tensorboard
tff.federated_computation(lambda: "Hello, World!")()
b'Hello, World!'

Define notebook parameters

You can easily try different setups by running the notebook with papermill using different parameters.

seed = 0
data_filepath_parts = ("data", "preprocessed_data", "data.f")  # for pathlib
test_split = 0.2

NUM_EPOCHS = 20
BATCH_SIZE = 32
SHUFFLE_BUFFER = 100
FEDERATED_UPDATES = 50

Make immediate derivations from the parameters:

np.random.seed(seed)
tf.random.set_seed(seed)
data_filepath = Path.cwd() / Path(*data_filepath_parts)

Load Data

df = feather.read_feather(data_filepath)
df.head()
   client_id                                                  x   y
0          0  [0, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...   3
1          0  [1, 1, 2, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...   1
2          0  [2, 1, 2, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...  15
3          1  [0, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...   3
4          1  [1, 3, 2, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ...   1

How many data points

df.shape
(21488, 3)

How many features

NUM_FEATURES = df.x[0].shape[0]
NUM_FEATURES
20
NUNIQUE_LABELS = df.y.nunique()
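`SparseCategoricalCrossentropy` expects integer labels in the range `[0, num_classes)`, so the softmax output needs more units than the largest label. The model below uses `NUNIQUE_LABELS + 1` units, which is only guaranteed to be enough when the labels are (nearly) contiguous from 0. A small stdlib sketch on toy label lists illustrates the difference; the helpers `required_output_units` and `nunique_plus_one` are hypothetical, and `len(set(labels))` plays the role of `df.y.nunique()`:

```python
def required_output_units(labels):
    """Smallest softmax width that covers every 0-based integer label."""
    return max(labels) + 1


def nunique_plus_one(labels):
    """Width rule used in this notebook: number of distinct labels plus one."""
    return len(set(labels)) + 1


# contiguous labels 0..3: the notebook rule gives enough units
contiguous = [0, 1, 2, 3, 1, 2]
# non-contiguous labels: only 3 distinct values, but label 15 needs 16 units
gappy = [1, 3, 15, 3, 1]

for labels in (contiguous, gappy):
    enough = nunique_plus_one(labels) >= required_output_units(labels)
    print(nunique_plus_one(labels), required_output_units(labels), enough)
```

If the labels turn out not to be contiguous, `int(df.y.max()) + 1` would be the safer width.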

Convert into TFF ClientData and split it by client into training and test datasets:

client_data = create_tff_client_data_from_df(df, sample_size=1)
train_data, test_data = train_test_client_split(
    client_data, int(df.client_id.nunique() * test_split)
)
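`train_test_client_split` splits by client rather than by example, so all data of a given client ends up on one side and no client appears in both sets. A minimal stdlib sketch of that idea (the helper `split_client_ids` is hypothetical and is not TFF's implementation):

```python
import random


def split_client_ids(client_ids, num_test_clients, seed=None):
    """Split client ids (not individual examples) into disjoint train/test lists."""
    ids = list(client_ids)
    random.Random(seed).shuffle(ids)  # seeded shuffle for a reproducible split
    return ids[num_test_clients:], ids[:num_test_clients]


client_ids = [str(i) for i in range(10)]
test_split = 0.2
train_ids, test_ids = split_client_ids(
    client_ids, int(len(client_ids) * test_split), seed=0
)
print(len(train_ids), len(test_ids))  # 8 2
```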

Train and test dataset sizes (number of clients in each set):

len(train_data.client_ids)
2141
len(test_data.client_ids)
998
ELEMENT_SPEC = train_data.element_type_structure
ELEMENT_SPEC
OrderedDict([('x', TensorSpec(shape=(None, 20), dtype=tf.int64, name=None)),
             ('y', TensorSpec(shape=(None,), dtype=tf.int32, name=None))])

Create Federated ML Process with TFF

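def create_keras_model():
    """
    Return a new Keras model instance.
    """
    visible = tf.keras.layers.Input(shape=(NUM_FEATURES,))

    # NOTE: despite its name, this layer has no activation (it is linear), so it
    # composes with the output layer into a single linear map before the softmax.
    # The results in this notebook were produced with activation=None.
    hidden1 = tf.keras.layers.Dense(
        48,
        activation=None,
        name="l1relu",
    )(visible)
    # SparseCategoricalCrossentropy needs more output units than the largest
    # integer label; NUNIQUE_LABELS + 1 assumes the labels are (nearly) contiguous.
    output = tf.keras.layers.Dense(
        NUNIQUE_LABELS + 1,
        activation="softmax",
        name="l3softmax",
    )(hidden1)
    model = tf.keras.Model(inputs=visible, outputs=output)
    return model
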
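Because `hidden1` has `activation=None`, stacking it before the output layer only composes two affine maps, which equals one affine map with `W = W1 @ W2` and `b = b1 @ W2 + b2`; the hidden layer adds parameters but no extra expressive power before the softmax. A pure-Python check with made-up toy weights (the `affine` and `matmul` helpers are hypothetical):

```python
def affine(weights, bias, x):
    """Dense layer without activation: y = x @ W + b (W has shape in_dim x out_dim)."""
    return [
        sum(x[i] * weights[i][j] for i in range(len(x))) + bias[j]
        for j in range(len(bias))
    ]


def matmul(a, b):
    """Plain matrix product of nested lists."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


# two stacked linear layers: 2 inputs -> 3 hidden -> 2 outputs (toy weights)
w1 = [[1.0, 0.0, 2.0], [0.5, -1.0, 1.0]]
b1 = [0.0, 1.0, -1.0]
w2 = [[1.0, 2.0], [0.0, 1.0], [-1.0, 0.5]]
b2 = [0.5, 0.0]

# collapsed single layer: W = W1 @ W2, b = b1 @ W2 + b2
w = matmul(w1, w2)
b = affine(w2, b2, b1)

x = [2.0, -1.0]
stacked = affine(w2, b2, affine(w1, b1, x))
single = affine(w, b, x)
print(stacked, single)  # [0.0, 6.0] [0.0, 6.0]
```

Adding a nonlinearity such as `activation="relu"` to `hidden1` is what would break this equivalence.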
def model_fn():
    """
    Create tff model (keras model + data format + loss & metrics)
    """
    # We _must_ create a new model here, and _not_ capture it from an external
    # scope. TFF will call this within different graph contexts.
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=collections.OrderedDict(
            x=ELEMENT_SPEC["x"],
            y=ELEMENT_SPEC["y"],
        ),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

Create federated averaging process:

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.RMSprop(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.RMSprop(learning_rate=1.0),
)
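In Federated Averaging the server combines the clients' locally trained weights into an example-weighted average. TFF's FedAvg process treats the gap between the current global weights and that average as a pseudo-gradient and passes it to the server optimizer, so plain SGD with `learning_rate=1.0` reproduces the simple average. The notebook instead uses `RMSprop(learning_rate=1.0)` on the server. A stdlib sketch on scalar toy weights compares the two; the helper names are hypothetical, and `rmsprop_first_step` assumes the Keras-style RMSprop rule with a zero-initialised accumulator, `rho=0.9` and `eps=1e-7`:

```python
import math


def weighted_average(client_weights, num_examples):
    """Example-weighted average of per-client weight vectors."""
    total = sum(num_examples)
    return [
        sum(w[d] * n for w, n in zip(client_weights, num_examples)) / total
        for d in range(len(client_weights[0]))
    ]


def sgd_step(weights, grad, lr):
    """Plain gradient descent step."""
    return [w - lr * g for w, g in zip(weights, grad)]


def rmsprop_first_step(weights, grad, lr, rho=0.9, eps=1e-7):
    """First RMSprop step from a zero accumulator (assumed Keras-style rule)."""
    rms = [(1.0 - rho) * g * g for g in grad]
    return [w - lr * g / (math.sqrt(v) + eps) for w, g, v in zip(weights, grad, rms)]


global_weights = [0.0, 1.0]
# weights returned by three clients after local training (toy values)
client_weights = [[0.2, 1.1], [0.4, 0.9], [0.0, 1.0]]
num_examples = [10, 30, 60]

avg = weighted_average(client_weights, num_examples)
# pseudo-gradient: how far the global model is from the client average
pseudo_grad = [g - a for g, a in zip(global_weights, avg)]

print(avg)
print(sgd_step(global_weights, pseudo_grad, lr=1.0))  # matches avg
print(rmsprop_first_step(global_weights, pseudo_grad, lr=1.0))
```

Under these assumptions the first RMSprop step moves every coordinate by roughly `lr / sqrt(1 - rho)`, about 3.16, however small the pseudo-gradient is. That is one plausible contributor to the jump in training loss from 2.30 to 14.91 after the first federated update in the output below.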

Initialize federated averaging process:

state = iterative_process.initialize()

Create federated evaluation process (validation):

evaluation = tff.learning.build_federated_evaluation(model_fn)
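The federated evaluation runs the global model on each sampled client and aggregates the metrics across clients. For Keras metrics such as `SparseCategoricalAccuracy`, TFF sums the underlying metric variables (correct predictions and example counts) over clients before taking the ratio, so clients with more examples weigh more. A stdlib sketch with toy per-client counts contrasts this pooled value with an unweighted mean of per-client accuracies (both helpers are hypothetical):

```python
def pooled_accuracy(per_client):
    """Sum correct predictions and examples over clients, then divide."""
    correct = sum(c for c, _ in per_client)
    total = sum(n for _, n in per_client)
    return correct / total


def mean_of_client_accuracies(per_client):
    """Unweighted average of each client's own accuracy."""
    return sum(c / n for c, n in per_client) / len(per_client)


# (correct predictions, examples) for three toy clients
per_client = [(1, 1), (2, 4), (3, 15)]
print(pooled_accuracy(per_client))  # 0.3
print(mean_of_client_accuracies(per_client))
```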

Function for sampling a random set of clients for one federated round and loading their datasets:

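def batch_client_data(client_data, batch_size=BATCH_SIZE):
    """
    Return the tf.data datasets of `batch_size` randomly chosen clients.
    """
    # sample `batch_size` distinct clients; replace=False avoids drawing
    # the same client twice within one round
    client_idx = np.random.choice(
        len(client_data.client_ids), size=batch_size, replace=False
    )
    batch = [
        client_data.create_tf_dataset_for_client(client_data.client_ids[idx])
        for idx in client_idx
    ]
    return batch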
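By default `np.random.choice` samples with replacement, so unless `replace=False` is passed the same client can be drawn twice in one round. A stdlib sketch of drawing a fixed number of distinct clients per round; `sample_clients` and the client ids are hypothetical, and `random.sample` draws without replacement:

```python
import random


def sample_clients(client_ids, num_clients, rng):
    """Pick num_clients distinct client ids for one federated round."""
    return rng.sample(list(client_ids), num_clients)


rng = random.Random(0)  # seeded so that the sequence of rounds is reproducible
client_ids = [f"client_{i}" for i in range(100)]

for round_num in range(3):
    chosen = sample_clients(client_ids, 32, rng)
    assert len(set(chosen)) == 32  # no client appears twice in a round
    print(round_num, chosen[:3])
```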

Federated training & evaluation:

%%time


def metrics_row(update, train_metrics, test_metrics):
    """Flatten the metrics of one federated update into a table row."""
    return {
        "federated_update": update,
        # train metrics are aggregated from the clients' local training
        "train_loss": train_metrics["train"]["loss"],
        "train_accuracy": train_metrics["train"]["sparse_categorical_accuracy"],
        "train_size": train_metrics["stat"]["num_examples"],
        "test_loss": test_metrics["loss"],
        "test_accuracy": test_metrics["sparse_categorical_accuracy"],
    }


rows = []
# update 0 is the first round; FEDERATED_UPDATES further rounds follow
for i in range(FEDERATED_UPDATES + 1):
    # one federated round on a fresh random sample of training clients
    state, train_metrics = iterative_process.next(state, batch_client_data(train_data))
    # evaluate the updated global model on a random sample of test clients
    test_metrics = evaluation(state.model, batch_client_data(test_data))
    # save results
    rows.append(metrics_row(i, train_metrics, test_metrics))
    # see progress
    print(
        "federated_update  {}, loss={:.2f}, accuracy={:.2f}".format(
            i, rows[-1]["train_loss"], rows[-1]["train_accuracy"]
        )
    )

metrics_df = pd.DataFrame(rows).set_index("federated_update")
federated_update  0, loss=2.30, accuracy=0.33
federated_update  1, loss=14.91, accuracy=0.07
federated_update  2, loss=12.05, accuracy=0.25
federated_update  3, loss=11.01, accuracy=0.32
federated_update  4, loss=14.65, accuracy=0.09
federated_update  5, loss=9.02, accuracy=0.44
federated_update  6, loss=12.66, accuracy=0.21
federated_update  7, loss=9.11, accuracy=0.43
federated_update  8, loss=11.26, accuracy=0.30
federated_update  9, loss=11.19, accuracy=0.30
federated_update  10, loss=10.50, accuracy=0.34
federated_update  11, loss=11.39, accuracy=0.29
federated_update  12, loss=8.87, accuracy=0.44
federated_update  13, loss=9.65, accuracy=0.39
federated_update  14, loss=9.11, accuracy=0.43
federated_update  15, loss=8.27, accuracy=0.48
federated_update  16, loss=13.42, accuracy=0.16
federated_update  17, loss=6.81, accuracy=0.57
federated_update  18, loss=10.34, accuracy=0.35
federated_update  19, loss=9.33, accuracy=0.42
federated_update  20, loss=9.08, accuracy=0.43
federated_update  21, loss=8.00, accuracy=0.50
federated_update  22, loss=11.33, accuracy=0.29
federated_update  23, loss=10.83, accuracy=0.32
federated_update  24, loss=11.91, accuracy=0.26
federated_update  25, loss=9.72, accuracy=0.39
federated_update  26, loss=11.46, accuracy=0.29
federated_update  27, loss=7.77, accuracy=0.51
federated_update  28, loss=10.68, accuracy=0.33
federated_update  29, loss=8.59, accuracy=0.46
federated_update  30, loss=10.04, accuracy=0.37
federated_update  31, loss=9.78, accuracy=0.38
federated_update  32, loss=9.94, accuracy=0.38
federated_update  33, loss=12.65, accuracy=0.21
federated_update  34, loss=7.63, accuracy=0.52
federated_update  35, loss=11.50, accuracy=0.28
federated_update  36, loss=8.16, accuracy=0.49
federated_update  37, loss=10.84, accuracy=0.32
federated_update  38, loss=7.17, accuracy=0.54
federated_update  39, loss=12.02, accuracy=0.24
federated_update  40, loss=9.21, accuracy=0.43
federated_update  41, loss=10.20, accuracy=0.36
federated_update  42, loss=9.85, accuracy=0.39
federated_update  43, loss=10.54, accuracy=0.34
federated_update  44, loss=9.02, accuracy=0.42
federated_update  45, loss=10.84, accuracy=0.32
federated_update  46, loss=9.34, accuracy=0.42
federated_update  47, loss=12.42, accuracy=0.22
federated_update  48, loss=9.63, accuracy=0.40
federated_update  49, loss=10.28, accuracy=0.35
federated_update  50, loss=10.26, accuracy=0.36
CPU times: user 6min 14s, sys: 14.5 s, total: 6min 28s
Wall time: 3min 1s
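Each round trains and evaluates on a fresh random sample of clients, so the per-round metrics above fluctuate strongly (training loss between 6.81 and 14.91 for updates 1 to 50). A rolling mean makes any trend easier to see; in the notebook this would be `metrics_df.rolling(5).mean()` in pandas. A stdlib sketch of the same computation on the first ten training-loss values printed above (`rolling_mean` is a hypothetical helper):

```python
def rolling_mean(values, window):
    """Mean of each full window of consecutive values.

    Like pandas rolling(window).mean(), minus the leading NaNs.
    """
    return [
        sum(values[i - window + 1 : i + 1]) / window
        for i in range(window - 1, len(values))
    ]


# training loss for federated updates 0-9, copied from the output above
train_loss = [2.30, 14.91, 12.05, 11.01, 14.65, 9.02, 12.66, 9.11, 11.26, 11.19]
smoothed = rolling_mean(train_loss, window=5)
print([round(v, 2) for v in smoothed])
```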

Create baseline (centralized model)

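Compare federated learning with the same model trained in a centralized setup. The learning speed is not directly comparable: the centralized model is trained for a number of full epochs over the training set, whereas federated learning alternates local training on a sample of clients with federated updates that aggregate the client models on the server. Note also that the centralized model uses RMSprop with its default learning rate, while the federated process uses the client and server learning rates set above.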
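One way to put both setups on a common axis, assuming a roughly equal cost per example, is the number of training examples processed: a centralized epoch processes the whole training set once, while each federated update processes only the sampled clients' examples (the `train_size` column of `metrics_df`). A sketch with a hypothetical helper and toy numbers:

```python
from itertools import accumulate


def epoch_equivalents(examples_per_update, train_set_size):
    """Cumulative federated examples, expressed in centralized epochs."""
    return [seen / train_set_size for seen in accumulate(examples_per_update)]


# toy numbers: four federated updates of about 200 examples, 1000-example training set
print(epoch_equivalents([200, 180, 220, 200], train_set_size=1000))
```

In the notebook, `metrics_df["train_size"].cumsum()` divided by the number of training examples would give the same quantity; that total is not computed in this notebook.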

%%time
# create train and test data for centralized setup
# (use the same train and test clients, although the data is passed to the models differently)


def to_features_and_label(element):
    """Turn a client-data element into an (x, y) pair.

    Each element is a batch (see ELEMENT_SPEC); [0] keeps its first row,
    which assumes one row per element.
    """
    return element["x"][0], element["y"][0]


train_data_centralized = (
    train_data.create_tf_dataset_from_all_clients()
    .map(to_features_and_label)
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# no shuffling needed for evaluation
test_data_centralized = (
    test_data.create_tf_dataset_from_all_clients()
    .map(to_features_and_label)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# create centralized model (same architecture as in the federated setup)
centralized_model = create_keras_model()

centralized_model.compile(
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
    optimizer="RMSprop",  # default learning rate
)

# fit and evaluate; batch_size is not passed because the datasets are already batched
history = centralized_model.fit(
    train_data_centralized,
    validation_data=test_data_centralized,
    epochs=NUM_EPOCHS,
)

# view results
pd.DataFrame(history.history).plot()
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7ff4d59c4f70> and will run it as-is.
Cause: could not parse the source code of <function <lambda> at 0x7ff4d59c4f70>: no matching AST found
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
WARNING:tensorflow:AutoGraph could not transform <function <lambda> at 0x7ff4d4001670> and will run it as-is.
Cause: could not parse the source code of <function <lambda> at 0x7ff4d4001670>: no matching AST found
To silence this warning, decorate the function with @tf.autograph.experimental.do_not_convert
Epoch 1/20
210/210 [==============================] - 10s 46ms/step - loss: 1.8254 - sparse_categorical_accuracy: 0.4806 - val_loss: 1.4598 - val_sparse_categorical_accuracy: 0.6456
Epoch 2/20
210/210 [==============================] - 10s 46ms/step - loss: 1.4857 - sparse_categorical_accuracy: 0.6144 - val_loss: 1.2795 - val_sparse_categorical_accuracy: 0.6621
Epoch 3/20
210/210 [==============================] - 9s 45ms/step - loss: 1.5108 - sparse_categorical_accuracy: 0.5861 - val_loss: 1.3949 - val_sparse_categorical_accuracy: 0.6128
Epoch 4/20
210/210 [==============================] - 10s 46ms/step - loss: 1.3662 - sparse_categorical_accuracy: 0.6013 - val_loss: 1.2798 - val_sparse_categorical_accuracy: 0.6761
Epoch 5/20
210/210 [==============================] - 10s 45ms/step - loss: 1.2101 - sparse_categorical_accuracy: 0.6661 - val_loss: 1.1067 - val_sparse_categorical_accuracy: 0.6978
Epoch 6/20
210/210 [==============================] - 9s 44ms/step - loss: 1.3494 - sparse_categorical_accuracy: 0.6034 - val_loss: 1.3121 - val_sparse_categorical_accuracy: 0.6304
Epoch 7/20
210/210 [==============================] - 10s 46ms/step - loss: 1.2088 - sparse_categorical_accuracy: 0.6489 - val_loss: 1.0820 - val_sparse_categorical_accuracy: 0.6790
Epoch 8/20
210/210 [==============================] - 10s 46ms/step - loss: 1.1442 - sparse_categorical_accuracy: 0.6738 - val_loss: 1.1076 - val_sparse_categorical_accuracy: 0.6650
Epoch 9/20
210/210 [==============================] - 10s 47ms/step - loss: 1.2855 - sparse_categorical_accuracy: 0.6379 - val_loss: 1.1851 - val_sparse_categorical_accuracy: 0.6738
Epoch 10/20
210/210 [==============================] - 10s 45ms/step - loss: 1.4220 - sparse_categorical_accuracy: 0.5843 - val_loss: 1.3356 - val_sparse_categorical_accuracy: 0.6245
Epoch 11/20
210/210 [==============================] - 10s 46ms/step - loss: 1.3806 - sparse_categorical_accuracy: 0.6065 - val_loss: 1.3161 - val_sparse_categorical_accuracy: 0.6449
Epoch 12/20
210/210 [==============================] - 10s 46ms/step - loss: 1.2324 - sparse_categorical_accuracy: 0.6366 - val_loss: 1.1956 - val_sparse_categorical_accuracy: 0.6534
Epoch 13/20
210/210 [==============================] - 10s 46ms/step - loss: 1.3406 - sparse_categorical_accuracy: 0.6060 - val_loss: 1.2209 - val_sparse_categorical_accuracy: 0.6310
Epoch 14/20
210/210 [==============================] - 10s 45ms/step - loss: 1.3025 - sparse_categorical_accuracy: 0.6123 - val_loss: 1.2510 - val_sparse_categorical_accuracy: 0.6388
Epoch 15/20
210/210 [==============================] - 10s 46ms/step - loss: 1.2434 - sparse_categorical_accuracy: 0.6559 - val_loss: 1.2067 - val_sparse_categorical_accuracy: 0.6654
Epoch 16/20
210/210 [==============================] - 10s 46ms/step - loss: 1.4265 - sparse_categorical_accuracy: 0.6081 - val_loss: 1.3274 - val_sparse_categorical_accuracy: 0.6702
Epoch 17/20
210/210 [==============================] - 10s 45ms/step - loss: 1.2236 - sparse_categorical_accuracy: 0.6579 - val_loss: 1.1487 - val_sparse_categorical_accuracy: 0.6907
Epoch 18/20
210/210 [==============================] - 10s 45ms/step - loss: 0.9367 - sparse_categorical_accuracy: 0.7356 - val_loss: 0.8296 - val_sparse_categorical_accuracy: 0.7944
Epoch 19/20
210/210 [==============================] - 10s 45ms/step - loss: 1.3789 - sparse_categorical_accuracy: 0.5812 - val_loss: 1.3323 - val_sparse_categorical_accuracy: 0.6054
Epoch 20/20
210/210 [==============================] - 10s 45ms/step - loss: 1.3409 - sparse_categorical_accuracy: 0.6163 - val_loss: 1.2852 - val_sparse_categorical_accuracy: 0.6177
CPU times: user 4min 18s, sys: 14.3 s, total: 4min 32s
Wall time: 3min 48s
<AxesSubplot:>
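The plots below take the last epoch of `history.history` as the centralized baseline. Because the validation metrics fluctuate strongly between epochs (validation loss 0.8296 in epoch 18 but 1.2852 in epoch 20), it can help to compare a few alternative summaries. The sketch below is a hypothetical helper (`baseline_candidates` is not part of the notebook) that works on a plain dict shaped like `history.history`. Note that picking the best epoch on the same data used for testing gives an optimistic baseline.

```python
# Hypothetical helper: summarize one metric of a Keras-style history dict
# (metric name -> list of per-epoch values) into candidate baseline values.
from statistics import mean
from typing import Dict, List


def baseline_candidates(
    history: Dict[str, List[float]],
    key: str,
    lower_is_better: bool,
    window: int = 5,
) -> Dict[str, float]:
    values = history[key]
    best = min(values) if lower_is_better else max(values)
    return {
        "last": values[-1],  # the value used for the baseline in the plots
        "best": best,  # optimistic if selected on the test data
        "best_epoch": values.index(best) + 1,  # 1-based epoch number
        f"mean_last_{window}": mean(values[-window:]),  # smooths epoch-to-epoch noise
    }


# validation losses copied from the centralized training log above
val_loss = [
    1.4598, 1.2795, 1.3949, 1.2798, 1.1067, 1.3121, 1.0820, 1.1076, 1.1851, 1.3356,
    1.3161, 1.1956, 1.2209, 1.2510, 1.2067, 1.3274, 1.1487, 0.8296, 1.3323, 1.2852,
]
print(baseline_candidates({"val_loss": val_loss}, "val_loss", lower_is_better=True))
```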

Plot results:

fig, ax = plt.subplots(1, constrained_layout=True)
# federated results
metrics_df.plot(ax=ax, y=["train_loss", "test_loss"], color=["royalblue", "red"])
# centralized baseline
ax.axhline(
    history.history["val_loss"][-1],
    linestyle="--",
    color="black",
    label="baseline (centralized computation)",
)
ax.set_ylabel("loss (sparse_categorical_crossentropy)")
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.legend()
ax.set_ylim(bottom=0)
(0.0, 16.483655858039857)
fig, ax = plt.subplots(1, constrained_layout=True)
# plot federated learning results
metrics_df.plot(
    ax=ax, y=["train_accuracy", "test_accuracy"], color=["royalblue", "red"]
)
# centralized baseline
ax.axhline(
    history.history["val_sparse_categorical_accuracy"][-1],
    linestyle="--",
    color="black",
    label="baseline (centralized computation)",
)
# random guessing baseline
ax.axhline(
    1 / NUNIQUE_LABELS, linestyle="-.", color="grey", label="baseline (random guessing)"
)
ax.legend()
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_ylim(0, 1)
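# make sure the output directory exists before saving
results_dir = Path.cwd() / "results"
results_dir.mkdir(parents=True, exist_ok=True)
fig.savefig(results_dir / "accuracy.png")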

With the current setup, test accuracy reaches 50% or slightly above (depending on the run). However, accuracy is unstable between rounds, and at times the model falls back to the level of random guessing.
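To put a number on this instability, one can summarize the per-round accuracies. The sketch below is a hypothetical helper (`accuracy_stability` is not from the notebook) applied to the accuracies printed for federated updates 33 to 50 in the log above. Whether those printed values are training or evaluation metrics depends on the training loop defined earlier in the notebook. The chance level of 0.2 is a placeholder that assumes five classes. Substitute `1 / NUNIQUE_LABELS`, which is not shown in this section.

```python
# Sketch: quantify how much per-round federated accuracy fluctuates.
from statistics import mean, pstdev
from typing import Dict, Sequence


def accuracy_stability(
    accuracies: Sequence[float], chance_level: float, margin: float = 0.05
) -> Dict[str, float]:
    """Spread of per-round accuracies, mean absolute change between
    consecutive rounds, and the number of rounds whose accuracy is at most
    `margin` above the chance level."""
    deltas = [abs(b - a) for a, b in zip(accuracies, accuracies[1:])]
    return {
        "min": min(accuracies),
        "max": max(accuracies),
        "mean": mean(accuracies),
        "std": pstdev(accuracies),
        "mean_abs_change": mean(deltas),
        "rounds_near_chance": sum(acc <= chance_level + margin for acc in accuracies),
    }


# accuracies printed for federated updates 33-50 in the log above
logged = [0.21, 0.52, 0.28, 0.49, 0.32, 0.54, 0.24, 0.43, 0.36,
          0.39, 0.34, 0.42, 0.32, 0.42, 0.22, 0.40, 0.35, 0.36]
# placeholder chance level (assumes 5 classes); substitute 1 / NUNIQUE_LABELS
print(accuracy_stability(logged, chance_level=0.2))
```

A mean absolute change that is large relative to the mean accuracy indicates that consecutive rounds swing widely rather than improving steadily.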

Conclusions

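  • Federated learning of customer paths works even with relatively little data
  • It performs significantly better than random guessing
  • However, the centralized model still outperforms the federated computation (final validation accuracy of about 0.62, versus federated accuracies between 0.21 and 0.54 in rounds 33 to 50). More work is required to optimize the federated setup.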
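One way to back the claim "significantly better than random guessing" is a one-sided binomial test of the number of correct test predictions against the chance level `1 / NUNIQUE_LABELS`. The sketch below is an illustration, not the notebook's method. The counts in the example (500 of 1000 correct, five classes) are hypothetical placeholders; the real ones would come from the evaluation on `test_data`. The test assumes independent test examples. Events from the same customer path are likely correlated, so the resulting p-value may be optimistic.

```python
# Sketch: one-sided exact binomial test of "accuracy better than chance".
import math


def binomial_p_value(correct: int, total: int, chance: float) -> float:
    """P(X >= correct) for X ~ Binomial(total, chance), computed in log space
    so that large test sets do not overflow. Requires 0 < chance < 1."""
    if not 0.0 < chance < 1.0:
        raise ValueError("chance must be strictly between 0 and 1")
    if correct <= 0:
        return 1.0
    if correct > total:
        return 0.0
    log_p, log_q = math.log(chance), math.log1p(-chance)
    log_terms = [
        math.lgamma(total + 1) - math.lgamma(k + 1) - math.lgamma(total - k + 1)
        + k * log_p + (total - k) * log_q
        for k in range(correct, total + 1)
    ]
    # log-sum-exp for numerical stability
    peak = max(log_terms)
    return min(1.0, math.exp(peak) * sum(math.exp(t - peak) for t in log_terms))


# hypothetical numbers: 500 of 1000 test examples correct, 5 classes
print(binomial_p_value(correct=500, total=1000, chance=1 / 5))
```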