How the network architecture impacts robustness and why this does not matter.

Introduction

Among the many expectations of trustworthy AI is the desire for robust and consistent predictions, even in unexpected situations. In this article, I highlight the role of the network architecture and its components in creating robustness. With a couple of experiments, I demonstrate that these effects can be quite significant. For instance, I show a case where a ResNet-18 becomes a random guesser while other models still perform decently. Then comes the twist: the realisation that, although these effects can be large, they will not help us create robust AI if that AI is developed only in the lab.

Why robustness matters

Robustness refers to a model's ability to keep making valid predictions on data that is somewhat different from the data it was trained on (the training set). This is an important consideration, as a model should make reliable predictions in real-world situations, where we often cannot control the quality of the input data.

In this context, the jargon is independent and identically distributed (i.i.d.) data and out-of-distribution (o.o.d.) data. Typically, model validation (the test set) is done on i.i.d. data, resulting in high accuracies. However, this metric does not tell us anything about the model's ability to generalise to data outside this domain. O.o.d. data is semantically similar to the training data but has a shift in distribution: it is somehow unexpectedly different. By validating on o.o.d. data, we can see how well a model generalises, or how much it is over-fitting to the training set.
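As a concrete sketch (not the code from the Colab), the same trained model can simply be scored twice, once on an i.i.d. test set and once on an o.o.d. set; test_loader and ood_loader are hypothetical names here.

```python
import torch

def accuracy(model, loader, device="cpu"):
    """Fraction of correct predictions over a data loader."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            preds = model(images.to(device)).argmax(dim=1)
            correct += (preds == labels.to(device)).sum().item()
            total += labels.size(0)
    return correct / total

# Hypothetical loaders: test_loader holds i.i.d. data, ood_loader holds shifted data.
# acc_iid = accuracy(model, test_loader)   # typically high
# acc_ood = accuracy(model, ood_loader)    # reveals how well the model generalises
```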

The ability to generalise is what we strive for when we want models with high robustness. A great in-depth discussion of robustness, shortcut learning and the reasons why it occurs can be found in Geirhos et al. (2020).

The experiment

My experiment builds upon an experiment from Geirhos et al. In it, the researchers create their own dataset of images containing the shape of a star or a moon. The caveat in the dataset is that the shape changes position a bit, but always remains within a fixed quadrant.

The dataset contains shapes that are always in a fixed quadrant.

Next, a “model” is trained, and the authors demonstrate that when shapes are positioned in a quadrant they were not trained in, the model fails to make correct predictions. The interpretation is that a “shortcut” is learned: rather than recognising a shape by its form, the model associates position with shape.

The concept of shortcut learning is an insightful construct and a reminder of what models really are: they are lazy and learn to recognise patterns that can be very different from what humans perceive.

What bothered me, though, was that a CNN would perform so badly. It was unexpected that the model was not robust against location. Deep learning textbooks state that CNNs, thanks to their convolutions and striding behaviour, should build up tolerance towards such variances.

The convolutional filter looks at 9 pixels and "summarises" them into a single pixel. Repeating this step over multiple layers creates a growing "receptive field".
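As a small illustration of this mechanism (assuming a dummy 32×32 single-channel input, not the actual dataset dimensions), a 3×3 convolution with stride 2 halves the spatial resolution, so each output value summarises an increasingly large patch of the input:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 1, 32, 32)                               # dummy 32x32 grayscale image
conv = nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1)  # 3x3 filter, stride 2
print(conv(x).shape)  # torch.Size([1, 8, 16, 16]): resolution halved, so stacking such
                      # layers makes the receptive field of each output value grow quickly.
```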

I wanted to find out more about this. Specifically, I wanted to find out which architectures, or which components within an architecture, create robustness. Would it be possible to have an architecture that does not fail in such a spectacular fashion?

In focus were the following conditions:

  1. Testing different architectures: a plain CNN with strided convolutions, a CNN with max pooling, a CNN with global average pooling, a ResNet-18, and a ResNet-18 with anti-aliasing filters
  2. Testing with a new validation set containing different out-of-distribution data. I controlled the modifications, such as shape transformation (skew), scale, rotation, background, and combinations of these, to see how well the architectures could cope with different degrees of variance.

Technical discussion on the tested architectures

You might want to skip this part if you are not interested in the technical discussion of the tested architectures. The summary is this: all architectures drop in accuracy when tested on data that differs from the training distribution. However, different architectures degrade differently. On my validation dataset, a ResNet-18 with an anti-aliasing layer was the most robust. Without the anti-aliasing filter, though (a plain ResNet-18), the network became a random guesser. The models were implemented in PyTorch; the code is available in a Google Colab.

The five architectures I experimented with:

1 CNN with strided convolutions

In this model, there are three stride-2 convolutional layers. As a result, the feature maps are downsampled as they go deeper into the network. This effectively leads to “receptive fields” that increase in size, which should give the model some tolerance to modifications such as rotation and changes in position.

A CNN with strided convolutions downsamples in deeper layers. This should give it some tolerance to variations within images.
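A minimal sketch of what this model could look like; the channel counts and the 3-channel, 64×64 input size are assumptions for illustration (the exact layer sizes are in the Colab):

```python
import torch.nn as nn

class StridedCNN(nn.Module):
    """Three stride-2 convolutions, then a linear classifier on the flattened feature maps."""
    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, stride=2, padding=1), nn.ReLU(),   # 64x64 -> 32x32
            nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),  # 32x32 -> 16x16
            nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),  # 16x16 -> 8x8
        )
        self.classifier = nn.Linear(64 * 8 * 8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))
```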

2 CNN with max pooling

Max pooling layers should help to increase the size of the receptive field of the feature maps. In this model, I used three convolutional layers, each followed by a max pooling layer. The final feature maps are flattened and used as input for the last linear layer. This is very similar to the classic architecture used for the MNIST dataset; as a sanity check, the model scores an accuracy of 0.98 on MNIST.

CNN with max pooling.
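Again a minimal sketch with assumed channel counts and a 3-channel, 64×64 input; only the downsampling mechanism differs from model 1:

```python
import torch.nn as nn

class MaxPoolCNN(nn.Module):
    """Three stride-1 convolutions, each followed by 2x2 max pooling, then a linear classifier."""
    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 64x64 -> 32x32
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 32x32 -> 16x16
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 16x16 -> 8x8
        )
        self.classifier = nn.Linear(64 * 8 * 8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))
```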

3 CNN with global average pooling

The ResNet architecture is still an important architecture for vision models. Besides the residual skip connections, ResNet also has a “global average pooling” layer after all convolutional layers, just before the feedforward head starts.

I made a simple model with three non-subsampling convolutions, followed by a global average pooling layer before the feedforward layer. This means there is no striding or pooling behaviour in the convolutional part: the feature maps stay the same size as the input. Any effect therefore comes down to the final global average pooling.

Simplified visualisation of the network.

The global average pooling layer reduces every feature map to its average. As a result, the location of the pixels becomes irrelevant. It should therefore lead to a high tolerance to positional variance.

Taking an average should completely remove any locational properties a feature map contains.
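A sketch of this model, again with assumed channel counts; the key piece is the AdaptiveAvgPool2d(1), which collapses every feature map to a single averaged value:

```python
import torch.nn as nn

class GapCNN(nn.Module):
    """Three non-subsampling convolutions, global average pooling, then a linear classifier."""
    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)           # global average pooling: one value per map
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.gap(self.features(x)).flatten(1)    # shape (batch, 64); location is discarded
        return self.classifier(x)
```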

4 ResNet-18

The ResNet architecture has some additional tricks up its sleeve to make it a good vision architecture:

  1. Batch normalisation
  2. Skip connections
  3. Global average pooling

Although ResNet is a bit older (2015), it is still a respectable architecture and often a reference in benchmarks. I chose the ResNet-18 variant, which is not very deep, but compared to the CNN architectures in this experiment it contains quite a few hidden layers. Would adding depth increase robustness?
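A sketch of how the ResNet-18 can be set up with torchvision; training from scratch and the two-class head are assumptions here, the actual setup is in the Colab:

```python
import torch.nn as nn
from torchvision import models

model = models.resnet18(weights=None)          # assumed: trained from scratch (older torchvision: pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 2)  # two classes: star and moon
```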

5 ResNet-18 with anti-aliasing layers

Zhang (2019) proposed an additional layer type for ResNet that should make it highly tolerant towards shifts and transformations. The intuition behind the layer is that before a downsampling step is applied, the feature map is first blurred. This is an “anti-aliasing” method, which makes sure that edges remain smooth. Conveniently, the author provides a PyTorch implementation that operates on a ResNet-18, which allows me to compare the results.
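A sketch of how the anti-aliased variant can be loaded, assuming the antialiased-cnns pip package that accompanies the paper:

```python
import torch.nn as nn
import antialiased_cnns  # pip install antialiased-cnns

model = antialiased_cnns.resnet18(pretrained=False)  # downsampling layers blur (BlurPool) before striding
model.fc = nn.Linear(model.fc.in_features, 2)        # two classes: star and moon
```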

Experiment 1: testing against positional invariance

In the original experiment from Geirhos et al., shapes were not recognised when they were located in a quadrant they had not been trained in. Using a similar verification dataset, I tested the trained models to see whether they were more robust against position.

In the dataset used to test robustness, the shapes are positioned in the opposite quadrants. This follows the original experiment from Geirhos et al.

The results indeed indicate that a CNN alone will not provide robustness against (large) positional variance, even when there is striding or pooling behaviour. The introduction of a global average pooling layer, though, results in a high tolerance to positional variance.

Accuracy of the models on the o.o.d. dataset

Experiment 2: testing against other invariances

Three out of the five architectures maxed out in terms of accuracy on this o.o.d. dataset. As a next step, I created a trickier dataset with further transformations of the shapes, to see which architecture proved most robust. In this dataset, I made structured modifications in terms of scale, shape transformation, rotation, combinations of modifications, and different backgrounds.

Example images of the variations I applied. An argument can be made that some distortions are no longer a moon or a star.

The table below displays the accuracies on dataset 1 (only positional variance) and dataset 2 (my own dataset with many types of transformations).

Accuracies on both o.o.d. datasets

What I find surprising is how the CNN with only global average pooling “outperformed” the ResNet-18. This means that the additional components ResNet contains (layer depth, batch norm and residual connections) must have contributed to a loss of robustness. On the other hand, adding the anti-aliasing filter from Zhang leads to an increase in robustness for the ResNet architecture.

Some further observations:

  1. Models 1 and 2 had no robustness against large positional variance. However, if a shape was located in the quadrant it was trained on, they were tolerant towards rotation and transformation variances. This indeed fits the “common knowledge” that convolutions introduce robustness against (small) variances within the scope of the receptive field.
  2. All models become random guessers whenever the background is replaced with either a pattern or a photo. Not a single architecture can handle this change.

Experiment 3: data augmentation

Another strategy to make a model more robust against unexpected input is data augmentation. Data augmentation extends the dataset with modified copies of the images; examples are changes in brightness, cropping, skewing, rotation, and so on. With data augmentation I could further increase the accuracy from 80% to 88%.
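A sketch of the kind of augmentation pipeline this refers to; the chosen transforms, their parameters and the 64×64 crop size are illustrative assumptions (the exact pipeline is in the Colab):

```python
from torchvision import transforms

train_transforms = transforms.Compose([
    transforms.RandomRotation(degrees=20),                    # small rotations
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1),
                            scale=(0.8, 1.2), shear=10),      # shifts, scaling, skew
    transforms.ColorJitter(brightness=0.3, contrast=0.3),     # lighting changes
    transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),       # cropping
    transforms.ToTensor(),
])
```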

The twist: the fallacy in this line of thinking

The results of these experiments suggest that the choice of network architecture plays a central role in creating robustness. One could be tempted to just stick to a ResNet-18 with anti-aliasing, add some data augmentation, and treat that as a general solution to vision problems. I was pleased with the results, until I came to an important realisation after reading a very informative paper:

Beede et al. (2020) describe their experience of deploying a vision classifier in the medical domain. The vision model could help with the identification of early diabetic retinopathy, a condition that can make you go blind if it is not detected early enough. Having achieved great accuracy scores during development, the team was confident the model was ready for a trial in a clinical setting. From the safe and controlled haven of the lab, the model was deployed in clinics in Thailand. There, the same type of scans were conducted, but under different conditions: the scans were made under lighting conditions that could not be controlled in the clinics, which turned out to be a large source of error.

Of course, there are follow-up steps to take from here. Perhaps a data augmentation strategy introducing modifications such as variations in brightness or contrast would help make the model robust in this setting.

The realisation I made is this: knowledge of which architecture or data augmentation technique is best only becomes available once you begin to understand what the real-world data looks like. A deployed model is likely to always face data that looks different from the training set. However, generalisability or robustness as a general principle is hard to pursue. You cannot just add “general generalisability” with some fancy layer. It is much better to deduce the right strategy from the data you receive from the real world. Architectures have characteristics, and the right choice of architecture and data augmentation technique depends on this data.

Final thoughts

This is a similar line of conclusion to the currently en vogue Data-Centric AI approach, where the focus is on collecting more suitable real-world data rather than fiddling with architectures to find solutions. Agile ML practices also promote this type of behaviour.

This is a bit of a cheat and not very “intelligent”. The approach means you inject human knowledge into the model, and it also means that new settings are likely to emerge where the model's performance degrades. Maintenance of models will therefore be an important part of the AI product life-cycle.

I believe this is where a product life-cycle challenge emerges. With normal "software" you can run in-house quality tests prior to running a beta test. With AI/ML, beta tests are critical yet cannot be fully prepared for. My take-away is that such initial deployments must come with strong quality controls.

Architecture choice can matter and affect robustness. However, there is no general solution that adds general robustness. What works best can only be deduced from inspecting and understanding real-world data. The context of deployment must be understood early in the design phase, and beta tests must have integrated quality mechanisms to make sure no harm is done.

References

Beede, E., Baylor, E., Hersch, F., Iurchenko, A., Wilcox, L., Ruamviboonsuk, P., & Vardoulakis, L. M. (2020). A Human-Centered Evaluation of a Deep Learning System Deployed in Clinics for the Detection of Diabetic Retinopathy. In Proceedings of the 2020 CHI Conference on Human Factors in Computing Systems (pp. 1–12). Association for Computing Machinery.

Zhang, R. (2019). Making Convolutional Networks Shift-Invariant Again. In Proceedings of the 36th International Conference on Machine Learning (ICML 2019) (pp. 12712–12722).

Geirhos, R., Jacobsen, J.-H., Michaelis, C., Zemel, R., Brendel, W., Bethge, M., & Wichmann, F. A. (2020). Shortcut Learning in Deep Neural Networks. Nature Machine Intelligence, 2, 665–673.