Neural learning rules for generating flexible predictions and computing the successor representation

  1. Ching Fang
  2. Dmitriy Aronov
  3. LF Abbott
  4. Emily L Mackevicius (corresponding author)
  1. Zuckerman Institute, Department of Neuroscience, Columbia University, United States
  2. Basis Research Institute, United States
7 figures and 1 additional file

Figures

The successor representation and an analogous recurrent network model.

(A) The behavior of an animal running down a linear track can be described as a sequence of transitions between discrete states, where each state encodes a spatial location. (B) By counting the transitions between different states, the behavior of an animal can be summarized in a transition probability matrix T. (C) The successor representation matrix is defined as $M=\sum_{t=0}^{\infty}\gamma^t T^t$. Here, M is shown for γ=0.6. Dashed boxes indicate the slices of M shown in (D) and (E). (D) The fourth row of the M matrix describes the activity of each state-encoding neuron when the animal is at the fourth state. (E) The fourth column of the M matrix describes the place field of the neuron encoding the fourth state. (F) Recurrent network model of the SR (RNN-S). The current state of the animal is one-hot encoded by a layer of input neurons. Inputs connect one-to-one onto RNN neurons with synaptic connectivity matrix J=T. The activity of the RNN neurons is represented by x. SR activity is read out through one-to-one connections from the RNN neurons to the output neurons. The example here shows inputs and outputs when the animal is at state 4. (G) Feedforward neural network model of the SR (FF-TD). The M matrix is encoded in the weights from the input neurons to the output layer neurons, where the SR activity is read out. (H) Diagram of the terms used for the RNN-S learning rule. Terms in red are used for potentiation while terms in blue are used for normalization (Equation 4). (I) As in (H) but for the feedforward-TD model (Equation 11). To reduce the notation indicating time steps, we use in place of (t) and no added notation for (t-1).
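The relationship between panels (B) through (E) can be illustrated with a minimal sketch. It assumes a row-stochastic convention (T[i, j] is the probability of moving from state i to state j) and uses a made-up six-state trajectory; the function names are hypothetical and are not taken from the authors' code. For 0 ≤ γ < 1 the infinite sum has the closed form (I − γT)^(-1), which the sketch checks against a truncated series.

```python
import numpy as np

def transition_matrix(states, n_states):
    """Estimate T by counting observed transitions, then normalizing each row."""
    counts = np.zeros((n_states, n_states))
    for s, s_next in zip(states[:-1], states[1:]):
        counts[s, s_next] += 1
    row_sums = counts.sum(axis=1, keepdims=True)
    # Rows of states that were never left stay at zero instead of dividing by zero.
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)

def successor_representation(T, gamma):
    """Closed form of sum_{t>=0} gamma^t T^t, valid for 0 <= gamma < 1."""
    return np.linalg.inv(np.eye(T.shape[0]) - gamma * T)

# Hypothetical behavior: 20 laps down a 6-state track, returning to the start after each lap.
states = list(range(6)) * 20
T = transition_matrix(states, n_states=6)
M = successor_representation(T, gamma=0.6)

# Truncated series, used only as a check on the closed form.
series = sum(0.6**t * np.linalg.matrix_power(T, t) for t in range(60))

activity_at_state_4 = M[3]         # fourth row, as in panel (D)
place_field_of_neuron_4 = M[:, 3]  # fourth column, as in panel (E)
```

Because every row of T sums to 1, every row of M sums to 1/(1 − γ), which is 2.5 for γ=0.6.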

Figure 2 with 1 supplement
Comparing the effects of an adaptive learning rate and plasticity kernels in RNN-S.

(A) Sample one-minute segments from random walks on a 1 meter circular track. Possible actions in this 1D walk are to move forward, stay in one place, or move backward. Action probabilities are uniform (top), biased to move forward (middle), or biased to stay in one place (bottom). (B) M matrices estimated by the RNN-S model in the full random walks from (A). (C) The proposed learning rate normalization. The learning rate $\eta_j$ for synapses out of neuron j changes as a function of its activity $x_j$ and recency bias λ. Dotted lines are at [0.0,0.5,1.0]. (D) The mean row sum of T over time computed by the RNN-S with an adaptive learning rate (blue) or the RNN-S with static learning rates (orange). Darker lines indicate larger static learning rates. Lines show the average over 5 simulations from walks with a forward bias, and shading shows 95% confidence interval. A correctly normalized T matrix should have a row sum of 1.0. (E) As in (D), but for the mean absolute error in estimating T. (F) As in (E), but for mean absolute error in estimating the real M, and with performance of FF-TD included, with darker lines indicating slower learning rates for FF-TD. (G) Lap-based activity map of a neuron from RNN-S with static learning rate $\eta=10^{-1.5}$. The neuron encodes the state at 45 cm on a circular track. The simulated agent is moving according to forward-biased transition statistics. (H) As in (G), but for RNN-S with adaptive learning rate. (I) The learning rate over time for the neuron in (G) (orange) and the neuron in (H) (blue). (J) Mean-squared error (MSE) at the end of meta-learning for different plasticity kernels. The pre→post ($K_+$) and post→pre ($K_-$) sides of each kernel were modeled by $Ae^{-t/\tau}$. Heatmap indices indicate the values to which the τs were fixed. Here, $K_+$ is always a positive function (i.e., A was positive), because performance was uniformly poor when $K_+$ was negative. $K_-$ could be either positive (left, “Post → Pre Potentiation”) or negative (right, “Post → Pre Depression”). Regions where the learned value for A was negligibly small were set to high errors. Errors are max-clipped at 0.03 for visualization purposes. 40 initializations were used for each $K_+$ and $K_-$ pairing, and the heatmap shows the minimum error achieved over all initializations. (K) Plasticity kernels chosen from the areas of lowest error in the grid search from (J). Left is post → pre potentiation. Right is post → pre depression. Kernels are normalized by the maximum, and dotted lines are at one second intervals.
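The two diagnostics in panels (D) and (E), the mean row sum of T and the mean absolute error in T, can be sketched as follows. This is an illustration only: it estimates T with a plain count-and-normalize estimator rather than the RNN-S learning rule, and the action probabilities, number of states, and walk length are hypothetical choices, not values from the simulations.

```python
import numpy as np

STEPS = np.array([1, 0, -1])  # forward, stay, backward

def simulate_walk(n_states, probs, n_steps, rng):
    """Random walk on a circular track; returns the visited state indices."""
    moves = rng.choice(STEPS, size=n_steps, p=probs)
    return np.mod(np.cumsum(moves), n_states)

def true_T(n_states, probs):
    """Ground-truth transition matrix implied by the action probabilities."""
    T = np.zeros((n_states, n_states))
    idx = np.arange(n_states)
    for step, p in zip(STEPS, probs):
        T[idx, (idx + step) % n_states] += p
    return T

def estimate_T(states, n_states):
    """Count transitions and normalize each row (not the RNN-S rule)."""
    counts = np.zeros((n_states, n_states))
    np.add.at(counts, (states[:-1], states[1:]), 1)
    row_sums = counts.sum(axis=1, keepdims=True)
    return np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)

rng = np.random.default_rng(0)
forward_biased = [0.6, 0.3, 0.1]  # hypothetical action probabilities
states = simulate_walk(10, forward_biased, 20000, rng)
T_hat = estimate_T(states, 10)

mean_row_sum = T_hat.sum(axis=1).mean()                  # 1.0 when correctly normalized
mae = np.abs(T_hat - true_T(10, forward_biased)).mean()  # error in estimating T
```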

Figure 2—figure supplement 1
Comparing model performance in different random walks.

(a-c) As in Figure 2D–F of the main document, but for a walk with uniform action probabilities.

Figure 3 with 1 supplement
RNN-S requires a stable choice of $\gamma_B$ during learning, and can compute SR with any $\gamma_R$.

(A) Maximum real eigenvalue of the J matrix at the end of random walks under different $\gamma_B$. The network dynamics were either fully linear (solid) or had a tanh nonlinearity (dashed). Red line indicates the transition into an unstable regime. 45 simulations were run for each $\gamma_B$, line indicates mean, and shading shows 95% confidence interval. (B) MAE of M matrices learned by RNN-S with different $\gamma_B$. RNN-S was simulated with linear dynamics (solid line) or with a tanh nonlinearity added to the recurrent dynamics (dashed line). Test datasets used various biases in action probability selection. (C) M matrix learned by RNN-S with tanh nonlinearity added in the recurrent dynamics. A forward-biased walk on a circular track was simulated, and $\gamma_B$=0.8. (D) The true M matrix of the walk used to generate (C). (E) Simulated population activity over the first ten laps in a circular track with $\gamma_B$=0.4. Dashed box indicates the retrieval phase, where learning is turned off and $\gamma_R$=0.9. Boxes are zoomed in on three minute windows.
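The idea that one learned weight matrix supports retrieval at any discount can be sketched as below. The sketch assumes linear dynamics and that learning has already converged so that J equals a row-stochastic T; the matrix is built analytically here rather than learned, and all function names and parameter values are hypothetical.

```python
import numpy as np

def circular_walk_T(n_states, p_forward, p_stay, p_backward):
    """Transition matrix of a 1D walk on a circular track."""
    T = np.zeros((n_states, n_states))
    idx = np.arange(n_states)
    T[idx, (idx + 1) % n_states] += p_forward
    T[idx, idx] += p_stay
    T[idx, (idx - 1) % n_states] += p_backward
    return T

def max_real_eigenvalue(J):
    """Largest real part among the eigenvalues of J."""
    return np.linalg.eigvals(J).real.max()

def retrieve_sr(J, gamma_r):
    """SR read out from fixed weights J at retrieval discount gamma_r."""
    if gamma_r * max_real_eigenvalue(J) >= 1.0:
        raise ValueError("gamma_r * J is outside the stable regime")
    return np.linalg.inv(np.eye(J.shape[0]) - gamma_r * J)

# Stand-in for weights at the end of a forward-biased walk (assumed equal to T).
J = circular_walk_T(16, 0.6, 0.3, 0.1)
lam_max = max_real_eigenvalue(J)  # 1 for a correctly normalized matrix

M_short = retrieve_sr(J, 0.4)  # short predictive horizon
M_long = retrieve_sr(J, 0.9)   # long predictive horizon, same weights
```

Only the discount changes between the two retrievals; the larger discount places more weight on states that are many transitions away.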

Figure 3—figure supplement 1
Understanding the effects of recurrency on stability.

(A) Mean absolute error (MAE) of M matrices learned by RNN-S with different baseline γ and different numbers of recurrent steps in dynamics. Test datasets used various biases in action probability selection. Errors are max-clipped at $10^1$ for visualization purposes. (B) M matrix learned by RNN-S with two recurrent steps in dynamics and baseline γ=0.8. A forward-biased walk on a circular track was simulated. (C) As in (B), but for four recurrent steps. (D) As in (B), but for five recurrent steps. Three examples are shown from different sampled walks to highlight the runaway activity of the network. (E) As in (B) but for the RNN-S activity calculated as $(I-\gamma J)^{-1}$. Note that this calculation amounts to an unstable fixed point in the dynamics that cannot be reached when the network is in an unstable regime. (F) Mean absolute error (MAE) in T made by RNN-S with linear dynamics using $\gamma_B$ during learning. (G) MAE in M for $\gamma_R$ made by RNN-S with linear dynamics using $\gamma_B$ during learning. (H) As in (G), but the dynamics now have a tanh nonlinearity.
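The distinction between a reachable and an unreachable fixed point can be sketched with a discrete-time linear recurrence. This is a simplified stand-in for the network dynamics, under stated assumptions: activity is a row vector updated as x ← input + γ x J, the stable weights are an analytically built transition matrix, and the unstable case is produced by an arbitrary 1.5× inflation of those weights.

```python
import numpy as np

def run_recurrent_steps(J, gamma, inp, n_steps):
    """Apply x <- inp + gamma * x J for n_steps, starting from x = inp."""
    x = inp.copy()
    for _ in range(n_steps):
        x = inp + gamma * (x @ J)
    return x

def fixed_point(J, gamma, inp):
    """Algebraic fixed point inp (I - gamma J)^(-1), whether or not it is reachable."""
    return inp @ np.linalg.inv(np.eye(J.shape[0]) - gamma * J)

# Forward-biased walk on a circular track with hypothetical probabilities.
n = 12
idx = np.arange(n)
T = np.zeros((n, n))
T[idx, (idx + 1) % n] = 0.6
T[idx, idx] = 0.3
T[idx, (idx - 1) % n] = 0.1

inp = np.eye(n)[3]  # one-hot input: the animal is at the fourth state
gamma = 0.8

# Stable regime: more recurrent steps bring the activity closer to the fixed point.
target = fixed_point(T, gamma, inp)
errors = {k: np.abs(run_recurrent_steps(T, gamma, inp, k) - target).sum() for k in (2, 4, 50)}

# Unstable regime (inflated weights): the fixed point exists but activity runs away.
J_unstable = 1.5 * T
unstable_target = fixed_point(J_unstable, gamma, inp)
runaway = run_recurrent_steps(J_unstable, gamma, inp, 200)
```

In the stable case the remaining error after k steps is the tail of the geometric series, so it shrinks as γ^(k+1). In the inflated case the same iteration grows without bound even though the matrix inverse is finite.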

Figure 4 with 1 supplement
Generalizing the model to more realistic inputs.

(A) Illustration of possible feature encodings ϕ for two spatially adjacent states in green and red. Feature encodings may vary in sparsity level and spatial correlation. (B) Average value of the STDP component (red) and the decorrelative normalization (solid blue) component of the gradient update over the course of a random walk. In dashed blue is a simpler Oja-like independent normalization update for comparison. Twenty-five simulations of forward-biased walks on a circular track were run, and shading shows 95% confidence interval. Input features are 3% sparse, with 10 cm spatial correlation. (C) Top: Example population activity of neurons in the RNN-S using the full decorrelative normalization rule over a 2 min window of a forward-biased random walk. Population activity is normalized by the maximum firing rate. Bottom: As above, but for RNN-S using the simplified normalization update. (D) Shifts in place field peaks after a half-hour simulation, relative to the first two minutes of a 1D walk. Proportion of shifts in RNN-S with one-hot inputs shown in gray. Proportion of shifts in RNN-S with feature encodings (10% sparsity, 7.5 cm spatial correlation, $\gamma_R$=0.8) shown in blue. Each data point is the average shift observed in one simulated walk, and each histogram is over 40 simulated walks. Solid line indicates the reported measure from Mehta et al., 2000.

Figure 4—figure supplement 1
Comparing place field shift and skew effects for different feature encodings.

(A–D) Average firing rate as a function of position on a circular track for four example neurons. The walk and feature encodings were generated as in Figure 4D of the main text. Each neuron is sampled from a different walk. ‘Before Learning’ refers to firing fields made from the first 2-minute window of the walk. ‘After Learning’ refers to firing fields made from the entire walk. (E–F) As in (A–D), but for two neurons from a walk where the features were one-hot encoded.

Figure 5 with 1 supplement
Fitting successor features to data with RNN-S over a variety of feature encodings.

(A) We use behavioral data from Payne et al., where a Tufted Titmouse randomly forages in a 2D environment while electrophysiological data is collected (replicated with permission from authors). Two example trajectories are shown on the right. (B) Temporal difference (TD) loss versus the spatial correlation of the input dataset, aggregated over all sparsity levels. Here, $\gamma_R$=0.75. Line shows mean, and shading shows 95% confidence interval. (C) As in (B), but measuring TD loss versus the sparsity level of the input dataset, aggregated over all spatial correlation levels. (D) TD loss for RNN-S with datasets with different spatial correlations and sparsities. Gray areas were not represented in the input dataset due to the feature generation process. Here, $\gamma_R$=0.75, and three simulations were run for each spatial correlation and sparsity pairing under each chosen $\gamma_R$. (E) As in (D), but for FF-TD. (F) TD loss of each model as a function of $\gamma_R$, aggregated over all input encodings. Line shows mean, and shading shows 95% confidence interval.
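A TD loss for successor features can be sketched as below. The caption does not define the loss, so this assumes the standard temporal difference error ϕ(t) + γ ψ(t+1) − ψ(t) with a linear readout ψ = ϕM, averaged as a mean squared error; the exact form used in the analysis may differ. The trajectory and features are hypothetical (deterministic laps with one-hot features) so that the correct answer is known.

```python
import numpy as np

def successor_features(phi, M):
    """Linear readout: psi_t = phi_t M for each row phi_t of phi."""
    return phi @ M

def td_loss(phi, M, gamma):
    """Mean squared TD error phi_t + gamma * psi_{t+1} - psi_t over a trajectory."""
    psi = successor_features(phi, M)
    td_error = phi[:-1] + gamma * psi[1:] - psi[:-1]
    return float(np.mean(td_error ** 2))

n_states, gamma_r = 5, 0.75
states = np.arange(50) % n_states          # hypothetical deterministic laps
phi = np.eye(n_states)[states]             # one-hot features, one row per time step
T = np.roll(np.eye(n_states), 1, axis=1)   # T[i, i+1] = 1 on a circular track
M_true = np.linalg.inv(np.eye(n_states) - gamma_r * T)

loss_true = td_loss(phi, M_true, gamma_r)                  # zero up to rounding
loss_untrained = td_loss(phi, np.eye(n_states), gamma_r)   # identity weights, no learned structure
```

With the true SR matrix the TD error vanishes, because M satisfies M = I + γTM; any other M leaves a positive loss on this trajectory.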

Figure 5—figure supplement 1
Parameter sweep details and extended TD error plots.

(A) The values of P (initial sparsity of random vectors before spatial smoothing) and σ sampled in our parameter sweep for Figures 5 and 6 in the main text. See methods 4.10 for more details of how feature encodings were generated. (B) The values of S (final sparsity of features, measured after spatial smoothing) and σ sampled in our parameter sweep for Figures 5 and 6 in the main text. (C) A sample state encoded by the firing rate of 200 input neurons. Here, S=0.11 and σ=2. (D) As in Figure 5F of the main text, with the results from a random feedforward network included (“Shuffle”). The random network was constructed by randomly drawing weights from the distribution of weights learned by the FF-TD network. The random network is representative of a model without learned structure but with a similar magnitude of weights as the FF-TD model. (E) Spatial correlation of the feature encoding for an example state with the features of all other states. The 14×14 states are laid out in their position in the 2D arena. Here, the sample state is the state in the center of the 2D arena and σ=2.0. (F) As in (E), but for σ=0.0. (G) As in Figure 5D of the main text, but for RNN-S (first row) and FF-TD (second row) with $\gamma_R$=0.4 (left column), $\gamma_R$=0.6 (middle column), and $\gamma_R$=0.8 (right column).

Figure 6 with 1 supplement
Comparing place fields from RNN-S to data.

(A) Dataset is from Payne et al., where a Tufted Titmouse randomly forages in a 2D environment while electrophysiological data is collected (replicated with permission from authors). (B) Distribution of place cells with some number of fields, aggregated over all cells recorded in all birds. (C) Distribution of place cells with some field size as a ratio of the size of the arena, aggregated over all cells recorded in all birds. (D) Average proportion of non-place cells in RNN-S, aggregated over simulations of randomly drawn trajectories from Payne et al. Feature encodings are varied by spatial correlation and sparsity as in Figure 5. Each simulation used 196 neurons. As before, three simulations were run for each spatial correlation and sparsity pairing under each chosen $\gamma_R$. (E) As in (D), but for average field size of place cells. (F) As in (D), but for average number of fields per place cell. (G) As in (D) and (E), but comparing place cell statistics using the KL divergence ($D_{KL}$) between RNN-S and data from Payne et al. At each combination of input spatial correlation and sparsity, the distribution of field sizes is compared to the neural data, as is the distribution of number of fields per neuron, then the two $D_{KL}$ values are summed. Contour lines are drawn at $D_{KL}$ values of 1, 1.5, and 2 bits. (H) Place fields of cells chosen from the region of lowest KL divergence. (I) As in (G) but for FF-TD. (J) Change in KL divergence for field size as a function of γ. Line shows mean, and shading shows 95% confidence interval. (K) Same as (J), but for number of fields.
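The summed KL divergence in panel (G) can be sketched as follows. Several details are assumptions: the direction of the divergence (data relative to model), the small constant added to avoid empty bins, and the binning. The histogram counts are placeholders and are not values from Payne et al. or from the simulations.

```python
import numpy as np

def kl_divergence_bits(p_counts, q_counts, eps=1e-9):
    """KL divergence D(P || Q) in bits between two histograms over the same bins."""
    p = np.asarray(p_counts, dtype=float) + eps  # eps guards against empty bins
    q = np.asarray(q_counts, dtype=float) + eps
    p /= p.sum()
    q /= q.sum()
    return float(np.sum(p * np.log2(p / q)))

# Placeholder histograms (cells per bin), for illustration only.
data_n_fields = [50, 30, 15, 5]
model_n_fields = [40, 35, 20, 5]
data_field_size = [10, 40, 30, 20]
model_field_size = [20, 30, 30, 20]

# The two divergences are summed into a single score.
total_dkl = (kl_divergence_bits(data_n_fields, model_n_fields)
             + kl_divergence_bits(data_field_size, model_field_size))
```

Using base-2 logarithms gives the result in bits, matching the units of the contour lines.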

Figure 6—figure supplement 1
Extended place field evaluation plots.

(A) As in Figure 6E–G of the main text, but for $\gamma_R$=0.4 (left column) and $\gamma_R$=0.8 (right column). In addition, the plots showing KL divergence (in bits) for the distribution of field sizes and number of fields per cell are shown. (B) As in (A) but for FF-TD. (C) As in Figure 6H of the main text, but for FF-TD with $\gamma_R$=0.4 and (D) FF-TD with $\gamma_R$=0.8. (E) Total KL divergence across $\gamma_R$ for RNN-S, FF-TD, the random network from Figure 6D (‘Shuffle’), and the split-half noise floor from the Payne et al. dataset (‘Data’). This noise floor is calculated by comparing the place field statistics of random halves of the neurons from Payne et al. We measure the KL divergence between the distributions calculated from each random half. This is repeated 500 times, and it is representative of a lower bound on KL divergence. Intuitively, it should not be possible to fit the data of Payne et al. as well as the dataset itself can.
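The split-half noise floor in panel (E) can be sketched as below, for a single statistic (number of fields per cell). The per-cell values are randomly generated placeholders rather than recorded data, and the binning, smoothing constant, and function names are assumptions.

```python
import numpy as np

def kl_bits(p_counts, q_counts, eps=1e-9):
    """KL divergence D(P || Q) in bits between two histograms over the same bins."""
    p = np.asarray(p_counts, dtype=float) + eps
    q = np.asarray(q_counts, dtype=float) + eps
    p /= p.sum()
    q /= q.sum()
    return float(np.sum(p * np.log2(p / q)))

def split_half_kl(values, bins, n_repeats=500, seed=0):
    """KL divergence between histograms of two random halves, repeated n_repeats times."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values)
    half = len(values) // 2
    out = np.empty(n_repeats)
    for k in range(n_repeats):
        perm = rng.permutation(len(values))
        hist_a, _ = np.histogram(values[perm[:half]], bins=bins)
        hist_b, _ = np.histogram(values[perm[half:]], bins=bins)
        out[k] = kl_bits(hist_a, hist_b)
    return out

# Placeholder data: number of fields (1 to 4) for 200 hypothetical cells.
fields_per_cell = np.random.default_rng(42).integers(1, 5, size=200)
noise_floor = split_half_kl(fields_per_cell, bins=np.arange(0.5, 5.5))
floor_estimate = noise_floor.mean()
```

The two halves come from the same population, so the divergence between them reflects sampling noise alone.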

Appendix 7—figure 1
SR matrices under different forms of normalization.

(A) The resulting SR matrix from a random walk on a circular track for 10 minutes, if the synaptic weight matrix exactly estimates the transition probability matrix (as in Equation 4). (B) Model as in (A), but with normalization removed. Thus, J will be equal to the count of observed transitions, i.e., $J_{ij}$ is equal to the number of experienced transitions from state j to state i. We will refer to this as a count matrix. The plot shows the maximum eigenvalue of the weight matrix, where an eigenvalue ≥ 1 indicates instability (Sompolinsky et al., 1988). (C) As in (B), but with an additional scaling factor α over the weights of the matrix, such that J is multiplied by $\frac{1}{\alpha \max(J)}$. (D) Steady state neural activity of the model in (C) with scaling factor 1.75. (E) As in (D), but the count matrix is instead scaled in a row-by-row fashion. Specifically, we divide each row i of the count matrix by the maximum of row i (and some global scaling factor to ensure stability).
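The count matrix and the two scaling schemes in panels (B), (C), and (E) can be sketched as follows. The walk parameters are hypothetical, and the global scaling is read as multiplication by 1/(α max(J)); the function names are illustrative only.

```python
import numpy as np

def count_matrix(states, n_states):
    """J[i, j] = number of experienced transitions from state j to state i."""
    J = np.zeros((n_states, n_states))
    np.add.at(J, (states[1:], states[:-1]), 1)
    return J

def globally_scaled(J, alpha):
    """Multiply every weight by 1 / (alpha * max(J))."""
    return J / (alpha * J.max())

def row_scaled(J, alpha):
    """Divide each row by its own maximum and by a global factor alpha."""
    row_max = J.max(axis=1, keepdims=True)
    return J / (alpha * np.where(row_max > 0, row_max, 1.0))

def max_real_eigenvalue(J):
    """Largest real part among the eigenvalues of J."""
    return np.linalg.eigvals(J).real.max()

# Hypothetical forward-biased walk on a 12-state circular track.
rng = np.random.default_rng(1)
moves = rng.choice([1, 0, -1], size=5000, p=[0.6, 0.3, 0.1])
states = np.mod(np.cumsum(moves), 12)

J = count_matrix(states, 12)
eig_counts = max_real_eigenvalue(J)                         # far above 1 without normalization
eig_global = max_real_eigenvalue(globally_scaled(J, 1.75))
eig_rows = max_real_eigenvalue(row_scaled(J, 1.75))
```

Because global scaling multiplies every eigenvalue by the same positive constant, a larger α always lowers the maximum eigenvalue; row-by-row scaling changes the relative weights between rows as well.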

Ching Fang, Dmitriy Aronov, LF Abbott, Emily L Mackevicius (2023)
Neural learning rules for generating flexible predictions and computing the successor representation
eLife 12:e80680.
https://doi.org/10.7554/eLife.80680