Computational Complexity Of Neural Networks
Why are neural networks so slow?
To motivate why we separate the training and inference phases of neural networks, it is useful to analyse their computational complexity.
This essay assumes familiarity with the analytical complexity analysis of algorithms, including big-O notation. If you need a recap, read the essay on computational complexity before continuing.
The inference part of a feed-forward neural network is forward propagation.
We can find the asymptotic complexity of forward propagation much like we found the run-time complexity of matrix multiplication.
Before beginning, you should be familiar with the forward propagation procedure.
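We assume the input vector, for $d$ input features, can be described as:
$$x = \begin{bmatrix} x_0 \\ x_1 \\ \vdots \\ x_d \end{bmatrix}$$
where the first element is the bias unit: $x_0 = 1$.
The input is treated in the same way as any other activation matrix, and has the index $0$: $a^{(0)} = x$. Its zeroth element, $a^{(0)}_0$, is as usual the bias unit with a value of 1.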
For forward propagation, we can write:
$$z^{(k)} = W^{(k)} a^{(k-1)}$$
$$a^{(k)} = g\left(z^{(k)}\right)$$
where $g$ is the activation function, which is evaluated elementwise. We therefore know that $a^{(k)}$ has the same dimensions as $z^{(k)}$.
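The two equations can be sketched in pure Python as follows. This sketch makes several assumptions: matrices are lists of rows, $g$ is the logistic sigmoid, and a bias unit of 1.0 is re-inserted in front of each hidden layer's activations, so each weight matrix here has one row per non-bias neuron. The names `sigmoid`, `matvec` and `forward` are hypothetical helpers, not part of any library.

```python
import math


def sigmoid(z):
    """Logistic function, used here as one possible choice of g."""
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W, a):
    """Naive product of matrix W (list of rows) and vector a."""
    return [sum(w * x for w, x in zip(row, a)) for row in W]


def forward(weights, x, g=sigmoid):
    """Forward propagation: z^(k) = W^(k) a^(k-1), a^(k) = g(z^(k)).

    x must already contain the bias unit x_0 = 1.0 as its first element.
    A bias unit is prepended to every hidden layer's activations, but not
    to the output layer's.
    """
    a = x  # a^(0) = x
    for k, W in enumerate(weights):
        z = matvec(W, a)
        a = [g(z_i) for z_i in z]  # g evaluated elementwise
        if k < len(weights) - 1:
            a = [1.0] + a  # bias unit for the next layer
    return a
```

For example, `forward([[[1.0, 2.0]]], [1.0, 3.0], g=lambda z: z)` uses the identity as $g$ and returns `[7.0]`, which is $1 \cdot 1 + 2 \cdot 3$.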
We see that each layer computes one matrix multiplication and one activation function. We know from the previous essay that naive matrix multiplication of two $n \times n$ matrices has an asymptotic run-time of $O(n^3)$. Since $g$ is an elementwise function, it has a run-time of $O(n)$, linear in the number of elements it is applied to.
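By analysing the dimensions of a feed-forward neural network, we find:
$$\underbrace{z^{(1)}}_{n_1 \times 1} = \underbrace{W^{(1)}}_{n_1 \times n_0}\,\underbrace{a^{(0)}}_{n_0 \times 1}, \qquad \underbrace{z^{(2)}}_{n_2 \times 1} = \underbrace{W^{(2)}}_{n_2 \times n_1}\,\underbrace{a^{(1)}}_{n_1 \times 1}$$
More generally, we can write:
$$\underbrace{z^{(k)}}_{n_k \times 1} = \underbrace{W^{(k)}}_{n_k \times n_{k-1}}\,\underbrace{a^{(k-1)}}_{n_{k-1} \times 1}$$
where $n_k$ is the number of neurons, including the bias unit, in layer $k$.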
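Recalling from the previous essay, multiplying an $m \times n$ matrix by an $n \times p$ matrix naively takes $m \cdot n \cdot p$ scalar multiplications:
$$O(m \cdot n \cdot p)$$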
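To see where the $m \cdot n \cdot p$ count comes from, here is a sketch of naive matrix multiplication that also counts its scalar multiplications. `matmul_count` is a hypothetical helper name.

```python
def matmul_count(A, B):
    """Naive product of an m x n matrix A and an n x p matrix B.

    Returns (C, mults), where mults is the number of scalar
    multiplications performed, which is exactly m * n * p.
    """
    m, n, p = len(A), len(B), len(B[0])
    assert all(len(row) == n for row in A), "inner dimensions must match"
    C = [[0.0] * p for _ in range(m)]
    mults = 0
    for i in range(m):
        for j in range(p):
            for k in range(n):
                C[i][j] += A[i][k] * B[k][j]
                mults += 1
    return C, mults
```

Multiplying a $2 \times 3$ matrix by a $3 \times 4$ matrix performs $2 \cdot 3 \cdot 4 = 24$ multiplications.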
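From this we find that:
$$n_{\text{mul}} = \sum_{k=1}^{L} n_k \cdot n_{k-1} \cdot 1, \qquad n_g = \sum_{k=1}^{L} n_k$$
where $n_{\text{mul}}$ is the number of multiplications performed, $n_g$ is how many times we apply the activation function $g$, and $L$ is the number of layers.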
This gives us a forward-propagation run-time of:
$$O\left(\sum_{k=1}^{L} \left(n_k n_{k-1} + n_k\right)\right)$$
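The sums can be evaluated directly for concrete layer sizes. The sketch below follows the simplified model used so far: a single input vector, and an $n_k \times n_{k-1}$ weight matrix in each layer. `forward_cost` is a hypothetical name.

```python
def forward_cost(layer_sizes):
    """Return (multiplications, activation evaluations) for one forward pass.

    layer_sizes = [n_0, n_1, ..., n_L]. Layer k multiplies an
    n_k x n_{k-1} weight matrix by an n_{k-1} x 1 activation vector
    (n_k * n_{k-1} * 1 multiplications) and applies g to the n_k
    elements of z^(k).
    """
    n_mul = 0
    n_g = 0
    for n_prev, n_k in zip(layer_sizes, layer_sizes[1:]):
        n_mul += n_k * n_prev * 1
        n_g += n_k
    return n_mul, n_g


print(forward_cost([3, 4, 2]))  # (20, 6)
```

For layer sizes 3, 4 and 2, this gives $4 \cdot 3 + 2 \cdot 4 = 20$ multiplications and $4 + 2 = 6$ activation evaluations.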
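When analysing matrix algorithms, it's common to assume that the matrices are square, meaning they have the same number of rows as columns. Here, that means every layer has the same number of neurons, $n_k = n$, and every matrix in the product $W^{(k)} a^{(k-1)}$ is treated as $n \times n$. The number of multiplications then becomes:
$$n_{\text{mul}} = \sum_{k=1}^{L} n \cdot n \cdot n = L n^3$$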
If we keep the same number of neurons in each layer and further assume that the number of layers equals the number of neurons in each layer, $L = n$, we find:
$$n_{\text{mul}} = n \cdot n^3 = n^4$$
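The same can be done for the activations, where $g$ is applied to each of the $n^2$ elements of $z^{(k)}$ in each of the $n$ layers:
$$n_g = n \cdot n^2 = n^3$$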
The total run-time of forward propagation therefore becomes:
$$O\left(n^4 + n^3\right) = O\left(n^4\right)$$
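A quick count confirms the growth rate. This uses the text's assumptions that there are $n$ layers and that every matrix is $n \times n$. `forward_cost_square` is a hypothetical name.

```python
def forward_cost_square(n):
    """Multiplications and activation evaluations for a forward pass with
    n layers, where every matrix is treated as n x n."""
    n_mul = 0
    n_g = 0
    for _layer in range(n):  # assumption: number of layers == n
        n_mul += n * n * n   # naive n x n by n x n product
        n_g += n * n         # g applied to every entry of z^(k)
    return n_mul, n_g


for n in (4, 8, 16):
    n_mul, n_g = forward_cost_square(n)
    print(n, n_mul, n_g, n_mul + n_g)
```

Each doubling of $n$ multiplies the multiplication count by $2^4 = 16$ but the activation count only by $2^3 = 8$, so the $n^4$ term dominates.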
We can find the run-time complexity of backpropagation in a similar manner.
Before beginning, you should be familiar with the backpropagation procedure.
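The error of the output layer, $\delta^{(L)}$, is computed only once, so we can safely ignore it: its cost is negligible next to the other terms. For the remaining layers, the delta error is:
$$\delta^{(k)} = \left(W^{(k+1)}\right)^T \delta^{(k+1)} \odot g'\left(z^{(k)}\right)$$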
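Each remaining layer therefore needs one $n \times n$ by $n \times n$ matrix product, one evaluation of $g'$ per element, and one elementwise product:
$$O\left(n^3 + n^2 + n^2\right) = O\left(n^3\right)$$
If we assume that there are $n$ layers, as before, the total run-time for the delta error then becomes:
$$O\left(n \cdot n^3\right) = O\left(n^4\right)$$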
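Here is a sketch of one step of the standard delta recursion, $\delta^{(k)} = (W^{(k+1)})^T \delta^{(k+1)} \odot g'(z^{(k)})$, for $n \times n$ matrices. It counts multiplications and $g'$ evaluations as it goes. It assumes the sigmoid activation, whose derivative is $g'(z) = g(z)(1 - g(z))$. `delta_step` and `sigmoid_prime` are hypothetical names.

```python
import math


def sigmoid_prime(z):
    """Derivative of the logistic sigmoid: g'(z) = g(z) * (1 - g(z))."""
    s = 1.0 / (1.0 + math.exp(-z))
    return s * (1.0 - s)


def delta_step(W_next, delta_next, z, g_prime=sigmoid_prime):
    """delta^(k) = (W^(k+1))^T delta^(k+1) ⊙ g'(z^(k)) for n x n matrices.

    Returns (delta, ops), where ops counts the scalar multiplications in
    the matrix product (n**3) plus one g' call and one elementwise
    multiplication per entry (2 * n**2).
    """
    n = len(z)
    delta = [[0.0] * n for _ in range(n)]
    ops = 0
    for i in range(n):
        for j in range(n):
            s = 0.0
            for k in range(n):
                s += W_next[k][i] * delta_next[k][j]  # (W^T)[i][k] == W[k][i]
                ops += 1
            delta[i][j] = s * g_prime(z[i][j])  # Hadamard product with g'
            ops += 2
    return delta, ops
```

Summing `ops` over the $n - 1$ layers below the output gives $(n - 1)(n^3 + 2n^2)$, which is $O(n^4)$.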
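The gradient of the weights in layer $k$ is $\frac{\partial E}{\partial W^{(k)}} = \delta^{(k)} \left(a^{(k-1)}\right)^T$. This is again an $n \times n$ by $n \times n$ product, costing $O(n^3)$. If we assume there are $n$ layers, finding the gradients of all the weights takes:
$$O\left(n \cdot n^3\right) = O\left(n^4\right)$$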
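The weight gradient $\delta^{(k)} (a^{(k-1)})^T$ can be sketched as follows. Here the samples are stored as columns and summed over. When every matrix is $n \times n$ (so the number of samples $m$ equals $n$), this costs $n^3$ multiplications per layer. `weight_gradient` is a hypothetical name.

```python
def weight_gradient(delta, a_prev):
    """dE/dW^(k) = delta^(k) (a^(k-1))^T, summed over the m samples.

    delta: n_k x m, a_prev: n_{k-1} x m (one sample per column).
    Returns (grad, mults), with grad of shape n_k x n_{k-1} and
    mults == n_k * n_{k-1} * m scalar multiplications.
    """
    n_k, n_prev, m = len(delta), len(a_prev), len(delta[0])
    grad = [[0.0] * n_prev for _ in range(n_k)]
    mults = 0
    for i in range(n_k):
        for j in range(n_prev):
            for s in range(m):
                grad[i][j] += delta[i][s] * a_prev[j][s]
                mults += 1
    return grad, mults
```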
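Plugging this into gradient descent, each iteration runs forward propagation, computes the delta errors and the gradients, and then updates every weight matrix with $W^{(k)} \leftarrow W^{(k)} - \alpha \frac{\partial E}{\partial W^{(k)}}$, where $\alpha$ is the learning rate. The update costs $O(n^2)$ per layer, so one iteration takes:
$$O\left(n^4 + n^4 + n^4 + n \cdot n^2\right) = O\left(n^4\right)$$
If we assume that gradient descent runs for $t$ iterations, the run-time is $O\left(t \cdot n^4\right)$.
So by assuming that gradient descent runs for $t = n$ iterations, we get:
$$O\left(n \cdot n^4\right) = O\left(n^5\right)$$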
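The pieces can be tallied in a rough operation-count model. The counts follow the simplified model in this essay ($n$ layers, every matrix $n \times n$), and the constant factors are only illustrative. `training_cost` is a hypothetical name.

```python
def training_cost(n, iterations):
    """Rough scalar-operation count for batch gradient descent, assuming
    n layers and that every matrix is n x n. Constants are illustrative;
    only the growth rate matters."""
    forward = n * (n**3 + n**2)           # W a products + g evaluations
    deltas = (n - 1) * (n**3 + 2 * n**2)  # delta recursion below the output
    grads = n * n**3                      # delta (a)^T for every layer
    update = n * n**2                     # W <- W - alpha * grad
    return iterations * (forward + deltas + grads + update)


for n in (4, 8, 16):
    print(n, training_cost(n, iterations=n))
```

With `iterations=n`, the count works out to $3n^5 + 3n^4 - 2n^3$, so the $n^5$ term dominates as $n$ grows.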
We see that the learning phase (backpropagation) is slower than the inference phase (forward propagation). The difference is even more pronounced because gradient descent often has to be repeated many times.
In fact, for a convex and smooth function, gradient descent needs $O(1/\epsilon)$ iterations to converge, where $\epsilon$ is the error of the final hypothesis.
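To make the $1/\epsilon$ dependence concrete, consider a convex function whose gradient is $\beta$-Lipschitz ($\beta$-smooth). This is an assumption beyond the text. Gradient descent with step size $1/\beta$ then satisfies the standard bound $f(x_t) - f^* \le \frac{\beta \lVert x_0 - x^* \rVert^2}{2t}$. Solving for $t$ gives the iteration count in this sketch. `iterations_bound` is a hypothetical name.

```python
import math


def iterations_bound(beta, R, eps):
    """Iterations after which gradient descent with step 1/beta guarantees
    f(x_t) - f* <= eps, for a convex, beta-smooth f with ||x_0 - x*|| <= R.
    Derived from f(x_t) - f* <= beta * R**2 / (2 * t)."""
    return math.ceil(beta * R ** 2 / (2 * eps))
```

For example, `iterations_bound(2.0, 1.0, 0.5)` is 2 and `iterations_bound(2.0, 1.0, 0.25)` is 4. Halving $\epsilon$ doubles the guaranteed iteration count.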
This results in a large constant factor that has real-world consequences, but which big-O notation doesn't show.
One way of making algorithms run faster is parallel execution, for example running the matrix operations on a GPU.
GPUs are specifically designed to run many matrix operations in parallel since 3D geometry and animation can be expressed as a series of linear transformations.
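Matrix products parallelise well because each row of the result depends only on its own row of the left operand. The sketch below illustrates that independence on the CPU with Python's `concurrent.futures.ThreadPoolExecutor`. In CPython, threads running pure-Python arithmetic will not actually speed this up because of the global interpreter lock. Real GPU code would use a library or kernel written for the GPU. `row_times_matrix` and `parallel_matmul` are hypothetical names.

```python
from concurrent.futures import ThreadPoolExecutor


def row_times_matrix(row, B):
    """Compute one row of A @ B; it depends only on this row of A and on B."""
    return [sum(a * B[k][j] for k, a in enumerate(row)) for j in range(len(B[0]))]


def parallel_matmul(A, B, workers=4):
    """Hand each row of A to a worker; map() returns results in row order."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        return list(pool.map(lambda row: row_times_matrix(row, B), A))


print(parallel_matmul([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```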
This is also why we usually train neural networks on GPUs.
It's worth mentioning that in 1988 Pitt and Valiant formulated an argument that if RP ≠ NP, which is currently not known, and if it's NP-hard to differentiate realizable hypotheses from unrealizable hypotheses, then a correct hypothesis cannot be found in polynomial time.
This, however, doesn't agree with the complexity we found for backpropagation, which is polynomial.
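Whether a set of examples $S = \{(x_1, y_1), \ldots, (x_m, y_m)\}$ is realizable by a hypothesis class $\mathcal{H}$ is defined as:
$$\exists\, h \in \mathcal{H} : \forall i \in \{1, \ldots, m\},\ \left|h(x_i) - y_i\right| < \epsilon$$
Here we see that the examples are realizable if some hypothesis $h$ in the class has an error of less than $\epsilon$ on every example, that is, if it correctly guesses all the examples within that error margin.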
Here, $h$ would be a neural network.
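To make the definition concrete, here is a check for a finite hypothesis class that uses absolute error as the error measure. Both are assumptions made for illustration: the class of functions a neural network can represent is infinite, so it cannot be enumerated like this. `is_realizable` is a hypothetical name.

```python
def is_realizable(hypotheses, examples, eps):
    """True if some hypothesis h in the (finite) class predicts every
    example (x, y) with |h(x) - y| < eps."""
    return any(
        all(abs(h(x) - y) < eps for x, y in examples)
        for h in hypotheses
    )


hypotheses = [lambda x: 2 * x, lambda x: x + 1]
print(is_realizable(hypotheses, [(1, 2), (2, 4)], eps=0.1))  # True: 2x fits both
print(is_realizable(hypotheses, [(1, 2), (2, 5)], eps=0.1))  # False: neither fits
```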
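Shalev-Shwartz argued against this by classifying neural networks, specifically deep neural networks, as doing improper learning. In improper learning, the learned hypothesis $h$ may come from a larger class than $\mathcal{H}$, so that $h \notin \mathcal{H}$. He argued that such an $h$ can still agree with the examples.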
The best run-time for neural networks is an area of active research.
We have derived the computational complexity of a feed-forward neural network. We have also seen why it's attractive to split the computation into a training and an inference phase: backpropagation with gradient descent, $O(n^5)$, is much slower than forward propagation, $O(n^4)$.
We have also considered the large constant factor that gradient descent needs to reach an acceptable accuracy, which strengthens the argument.
Furthermore, we have discussed some theoretical aspects of learning representations of functions, including the role of neural networks, but we were unable to reach a definitive conclusion.
When implementing neural networks, all the samples are often collected into a matrix with the dimensions $n_0 \times m$, where $m$ is the total number of samples in the training set. This saves time on allocating memory and is primarily a practical concern, which is why we won't consider it further in the theory.
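Batching is only a practical matter because, with the samples stored as the columns of an $n_0 \times m$ matrix $X$, the single product $W X$ contains $W x_j$ for each sample $x_j$ as its $j$-th column. The small sketch below shows this. The helper names `matmul` and `as_columns` are hypothetical.

```python
def matmul(A, B):
    """Naive product of an m x n matrix A and an n x p matrix B."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def as_columns(samples):
    """Stack sample vectors as the columns of a matrix."""
    return [list(row) for row in zip(*samples)]


W = [[1, 2], [3, 4]]
samples = [[1, 5], [1, 7]]                # bias unit first in each sample
batched = matmul(W, as_columns(samples))  # one product for all samples
one_by_one = [matmul(W, [[v] for v in x]) for x in samples]
print(batched)     # [[11, 15], [23, 31]]
print(one_by_one)  # [[[11], [23]], [[15], [31]]]
```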