Extracting latent variables from sensory stimuli.

Animals and humans must infer latent variables y, such as object identity, orientation, and velocity, from ambiguous and noisy sensory stimuli. This task is computationally demanding because sensory stimuli x are typically entangled (left). Untangling these input manifolds into internal representations cₜ that approximate the underlying latent variables (right) requires multi-layer networks (middle). How our brains achieve this untangling remains a central question in neuroscience.

Learning latent variables from motion stimuli without reconstruction and labels.

(a) Schematic of the network model. RPL consists of three network components. First, the “encoder” maps sensory stimuli x to intermediate embeddings z. Next, the “integrator” combines these embeddings into the internal representation c. Finally, the “predictor” network predicts future embeddings zₜ from cₜ, yielding the associated prediction error (PE). The learning objective is to minimize this PE. (b) Synthetic data generation paradigm with known latent variables. We generate motion trajectories by moving a 2D animal image, chosen randomly from one of eight categories (right), within a square arena (left). Each frame is characterized by the object position (r₁, r₂), its scale, which mimics position along the r₃ direction, and the orientation angle θ. For each trajectory, we randomly select a constant velocity v and angular velocity ω = θ̇, with elastic collisions at the arena borders. To obtain the pixel images that serve as network inputs, we place these motion trajectories on a gray background and add Gaussian pixel noise (Methods). (c) Two example trajectories shown frame by frame. (d) Scatter plots of latent variables linearly decoded from the learned internal representations c for a representative sample of cat (blue) and dog (red) input images: velocity (v₁, v₂) and object identity S (left); orientation (a, b) and S (right). (e) Decoded velocity components v₁ and v₂ as a function of ground truth on held-out data. (f) Same as (e) but for the decoded orientation components (a, b) as a function of θ. The model (RPL) has formed untangled and structured representations of the latent variables. (g) Linear readout accuracy of animal categories S (left) and coefficient of determination (R²) (right) for position r, velocity v, orientation θ = (a, b), and angular speed |ω|, for linear decoders trained on the internal representations c of n = 3 independent networks trained with the RPL and invariance learning (IL) objectives. Decoding performance for n = 1 network trained end-to-end with supervision (Sup.) and for a randomly initialized network (Rand.) is given for reference. RPL decoding accuracy for object identity and its R² values are close to those of networks trained end-to-end with supervised learning, whereas IL fails to accurately represent velocity, orientation, and angular speed. (h–j) Same as (d–f) but for IL.
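The trajectory generation in panel (b) can be illustrated with a small simulation. This is a minimal sketch, not the authors' generator: `simulate_trajectory` is a hypothetical helper that moves a point in the unit square with constant velocity and angular velocity, reflecting elastically off the walls. The scale (r₃) dimension, image rendering, and pixel noise are omitted, and it assumes the per-step displacement is smaller than the arena so one reflection per step suffices.

```python
import numpy as np


def simulate_trajectory(n_frames, v, omega, r0, theta0, arena=1.0, dt=1.0):
    """Toy 2D trajectory with constant velocity and angular velocity.

    Hypothetical helper: elastic reflection at the walls of [0, arena]^2.
    Assumes |v| * dt < arena, so at most one reflection per axis per step.
    """
    r = np.array(r0, dtype=float)
    v = np.array(v, dtype=float)
    theta = float(theta0)
    positions, velocities, thetas = [], [], []
    for _ in range(n_frames):
        # Record the latent variables of the current frame.
        positions.append(r.copy())
        velocities.append(v.copy())
        thetas.append(theta)
        r = r + v * dt
        for k in range(2):
            if r[k] < 0.0:        # lower wall: mirror position, flip velocity
                r[k], v[k] = -r[k], -v[k]
            elif r[k] > arena:    # upper wall: mirror position, flip velocity
                r[k], v[k] = 2.0 * arena - r[k], -v[k]
        # Constant angular velocity, angle kept in [0, 2*pi).
        theta = (theta + omega * dt) % (2.0 * np.pi)
    return np.array(positions), np.array(velocities), np.array(thetas)


pos, vel, theta = simulate_trajectory(50, v=(0.07, -0.05), omega=0.1,
                                      r0=(0.5, 0.5), theta0=0.0)
```

Because collisions are elastic, the speed along each axis stays constant and only the sign of a velocity component changes at a wall, which keeps the velocity a well-defined latent variable for decoding.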

RPL circuits learn an internal world model.

(a) Setup for simulating trajectories using internal representations. We added a dynamic switch that allows the feedforward input to the integrator to be replaced with internal predictions (cf. Fig. 1). To simulate future trajectories, we provide the network with external sensory input during an initial inference period. Afterwards, we replace the feedforward input to the integrator with predicted embeddings in an autoregressive manner. (b) Example of a decoded trajectory (red) compared to the ground truth (black) in the r₁–r₂ plane (left) and r₃ as a function of time (top right; cf. Fig. 2a), as well as the object decoder logits (bottom right). In this example, we used an initial inference period of eight frames of a cat trajectory (shaded), followed by 16 steps of internal simulation. (c) Decoded velocity and orientation during the inference period (shaded) and the simulation. The decoded values closely follow the ground truth. (d) Mean-squared error of the decoded quantities relative to the ground truth as a function of time. The error increases monotonically over simulation time.
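The switch in panel (a) amounts to a loop with two phases. The following is a minimal sketch, assuming hypothetical callables `encoder`, `integrator`, and `predictor` that stand in for the trained networks. The toy stand-ins (identity encoder, memoryless integrator, and a predictor that adds a fixed step) only illustrate the control flow, not RPL's learned dynamics.

```python
import numpy as np


def simulate(encoder, integrator, predictor, frames, n_infer, n_sim, c0):
    """Infer from real frames, then roll out autoregressively on predictions."""
    c = c0
    states = []
    # Inference period: the encoder's feedforward input drives the integrator.
    for t in range(n_infer):
        c = integrator(c, encoder(frames[t]))
        states.append(c)
    # Simulation period: the switch feeds predicted embeddings instead.
    for _ in range(n_sim):
        c = integrator(c, predictor(c))
        states.append(c)
    return np.stack(states)


# Hypothetical toy stand-ins for the trained networks.
encoder = lambda x: np.asarray(x, dtype=float)
integrator = lambda c, z: z
predictor = lambda c: c + np.array([1.0, 0.0])

frames = [[float(t), 0.0] for t in range(8)]  # eight "observed" frames
states = simulate(encoder, integrator, predictor, frames,
                  n_infer=8, n_sim=16, c0=np.zeros(2))
```

With eight inference frames and 16 simulation steps, as in panel (b), the sketch returns 24 internal states; in the figure, such states are read out with the linear decoders to obtain the plotted trajectories and errors.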

RPL learns at different levels of abstraction.

(a) An example sequence of digit triplets separated by zeros. (b) Schematic of the sequence generation paradigm. We generated triplet sequences containing randomly ordered handwritten digits sampled from one of three digit clusters. Transitions between clusters occurred with probability 0.2 and were marked by a “zero”. For any given sequence, each digit instance was randomly chosen from the MNIST dataset. (c) Schematic of the tasks used to investigate the emergence of abstraction, which ask how well the network representations encode cluster, digit, and triplet identity, respectively. Solving these tasks requires both a mapping from images to abstract digit and cluster representations and knowledge of the transition structure. To quantify representational quality, we trained a linear classifier for each task (Methods). (d) Linear decoding accuracy of the cluster, digit, and triplet identities for RPL and IL for n = 3 independently trained networks. Decoding accuracy for end-to-end supervised learning (solid line) and a randomly initialized network (dashed line) is given for reference. RPL achieves high accuracy on all tasks, whereas IL representations favor cluster identity. (e) Principal component (PC) projections of the learned representations colored by cluster identity. Both RPL and IL clearly separate the clusters. (f) Same as (e) but colored by temporal position within a triplet. Only RPL yields representations of the abstract temporal transition structure.
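The label sequence behind panel (b) can be sketched as follows. This is a minimal sketch under several assumptions not stated in the caption: the digit-to-cluster assignment in `CLUSTERS` is hypothetical, each triplet receives a fresh random order, and a switch picks one of the other two clusters uniformly. In the actual paradigm, each label would then be replaced by a randomly drawn MNIST image of that digit.

```python
import random

# Hypothetical cluster assignment; the caption does not specify which digits
# belong to which cluster.
CLUSTERS = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]


def generate_digit_labels(n_triplets, p_switch=0.2, seed=0):
    """Return digit labels plus cluster index and within-triplet position.

    Zeros mark cluster transitions and get cluster/position None.
    """
    rng = random.Random(seed)
    cluster = rng.randrange(len(CLUSTERS))
    labels, clusters, positions = [], [], []
    for _ in range(n_triplets):
        triplet = list(CLUSTERS[cluster])
        rng.shuffle(triplet)  # assumption: fresh random order for each triplet
        for pos, digit in enumerate(triplet):
            labels.append(digit)
            clusters.append(cluster)
            positions.append(pos)
        if rng.random() < p_switch:  # transition to another cluster, marked by a zero
            cluster = rng.choice([k for k in range(len(CLUSTERS)) if k != cluster])
            labels.append(0)
            clusters.append(None)
            positions.append(None)
    return labels, clusters, positions


labels, clusters, positions = generate_digit_labels(200, seed=1)
```

The returned cluster indices and within-triplet positions correspond to the kinds of targets probed by the linear classifiers in panels (c) and (f).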

RPL extracts latent variables from real-world data.

(a) Schematic of real-world video data of behaving mice in an open-field arena [47]. We used annotated key points to compute position r, velocity v, and orientation θ as proxies for the latent variables used in evaluation (Methods). (b) R² values of decoded position, velocity, and orientation for RPL and IL from n = 3 independently trained networks. Decoding performance for a randomly initialized network (dashed line) is shown for reference. (c) Schematic of real-world speech data from the LibriSpeech corpus [49]. We relied on speaker identity and phoneme labels to assess model performance. (d) Linear decoding accuracy of speaker identity and phoneme sequence for RPL and IL for n = 3 independently trained models. Decoding from a randomly initialized network is given for reference (dashed line). RPL achieves high decoding accuracy for both speaker identity and phonemes. In contrast, IL yields high accuracy only for speaker identity.
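The R² values reported here and in Fig. 2g come from linear decoders trained on frozen representations. As a sketch of that evaluation, the hypothetical `linear_decoding_r2` below fits an ordinary least-squares decoder with a bias term on training representations and reports held-out R² per target dimension. The decoder type, regularization, and data splits actually used are described in Methods and may differ.

```python
import numpy as np


def linear_decoding_r2(C_train, Y_train, C_test, Y_test):
    """Fit an OLS linear decoder c -> y and return held-out R^2 per output."""
    # Append a constant column so the decoder has a bias term.
    A_train = np.hstack([C_train, np.ones((len(C_train), 1))])
    W, *_ = np.linalg.lstsq(A_train, Y_train, rcond=None)
    A_test = np.hstack([C_test, np.ones((len(C_test), 1))])
    Y_pred = A_test @ W
    # R^2 = 1 - SS_res / SS_tot, computed separately for each target column.
    ss_res = ((Y_test - Y_pred) ** 2).sum(axis=0)
    ss_tot = ((Y_test - Y_test.mean(axis=0)) ** 2).sum(axis=0)
    return 1.0 - ss_res / ss_tot


rng = np.random.default_rng(0)
C = rng.normal(size=(200, 10))       # stand-in "internal representations"
B = rng.normal(size=(10, 3))
Y = C @ B + 0.5                      # latents that are exactly linear in C
r2 = linear_decoding_r2(C[:150], Y[:150], C[150:], Y[150:])
```

A latent that is linearly embedded in the representation yields R² near 1, whereas a target unrelated to the representation yields R² near or below 0 on held-out data, which is why a randomly initialized network serves as a useful baseline.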

RPL learns successor-like representations comparable to human V1.

(a) Schematic of the full and partial sequences used in the experiment. (b) Relative BOLD signal responses in human V1 [45]. Relative stimulus responses to partial sequence stimuli (left) and individual responses to each partial sequence (right). Partial sequence stimulation evoked brain activity encoding successor locations but not predecessor locations. Error bars represent the standard error of the mean. Data extracted and reproduced from Ekman et al. [45]. (c) Relative neuronal selectivity matrix extracted from the RPL model representations (n = 35 models). Negative values below the diagonal indicate inhibition of activity at predecessor locations, whereas positive values above the diagonal indicate that the stimulus drives the representations of successor locations. (d) Average neural activity for all partial sequences at non-stimulated and stimulated locations after removing the baseline activity, as in (b). Each dot is a trained model, and the solid lines represent the average activity over all models. The plot shows positive neural activity at stimulated and successor locations and no activity at predecessor locations. (e) Neural activity at all sequence locations in response to each partial sequence. The positive response decays across successor locations for individual partial sequences, hinting at successor-like representations.

RPL learns abstract sequence representations similar to macaque PFC.

(a) Schematic of the local-global oddball paradigm used by Bellet et al. [44]. Example trials, each consisting of four sequential stimuli, show an “xxxY” and an “xxxx” trial in the context of frequent “xxxY” presentations (top). A given trial can be a local standard or deviant, depending on the local sequence structure of the four stimuli, and a global standard or deviant, depending on the context set for a block of several trials (bottom). Figure adapted from [44]. (b) Comparison of temporal serial position representation in the experiment (top) and the model (bottom). Plots show the probabilities of the four serial positions (items 1 through 4) from the best linear decoder, averaged over all xxxx and xxxY trials. RPL learns representations that encode serial position more robustly than the macaque PFC. The experimental data correspond to Monkey A and have been reproduced from publicly available data and code. (c) Comparison of stimulus identity representation in the experiment (top) and the model (bottom). Plots show projections of neural activity or RPL representations onto the directions that best separate pairs of input stimuli (Item a vs. b) that comprise x and Y, averaged across trials (Methods). RPL representations encode stimulus identity similarly to the experiment. (d) Same as (c) but for the representation of the global context set by the frequent trial. (e) Same as (c) but for the presence or absence of a local deviant at the fourth sequence position. (f) Same as (c) but for the presence or absence of a global deviant, decoded from the fourth sequence position.
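The local/global labeling in panel (a) reduces to two comparisons. This minimal sketch encodes each trial as a four-character string such as "xxxY". `classify_trial` and `make_block` are hypothetical helpers, and the rare-trial probability `p_rare` is an illustrative value rather than the proportion used by Bellet et al.

```python
import random


def classify_trial(trial, frequent):
    """Label a four-item trial in the local-global paradigm.

    Local deviant: the fourth item differs from the preceding (identical) items.
    Global deviant: the trial differs from the block's frequent trial type.
    """
    local_deviant = trial[3] != trial[0]
    global_deviant = trial != frequent
    return local_deviant, global_deviant


def make_block(frequent, rare, n_trials, p_rare=0.2, seed=0):
    """Hypothetical block generator; p_rare is illustrative only."""
    rng = random.Random(seed)
    return [rare if rng.random() < p_rare else frequent for _ in range(n_trials)]


block = make_block("xxxY", "xxxx", n_trials=10)
labels = [classify_trial(t, frequent="xxxY") for t in block]
```

In an “xxxY” block, the frequent “xxxY” trial is a local deviant but a global standard, while the rare “xxxx” trial is a local standard but a global deviant. This dissociation is what panels (e) and (f) probe.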

© 2024, Bellet ME. Panel A was reprinted with permission from Figure 1 of Bellet et al. 2024 [44], which was published under a CC BY-NC-ND license (https://creativecommons.org/licenses/by-nc-nd/4.0). Further reproductions must adhere to the terms of this license.

Untangling in a cortical model with local prediction error circuits.

(a) Schematic of the network model consisting of six identical RPL circuits arranged in a hierarchy, each containing a shallow encoder, an integrator, and a predictor network. Each circuit computes its own prediction error, while Stop Grad (SG) operations prevent gradients from flowing between the individual circuits. In this h-RPL model, horizontal intra-laminar connectivity implements the encoder network, whereas vertical connectivity accounts for the integrator and predictor circuitry. Cortical anatomy drawing (right) adapted from Cajal [52]. (b) Decoding R² values and accuracy of the latent variables (cf. Fig. 2b,c) for each RPL circuit in h-RPL (red) and for a variant in which only the last circuit in the hierarchy was trained (gray). Shading corresponds to one standard deviation over n = 3 independent models. The solid black line denotes the performance of RPL with a deep encoder (cf. Fig. 2j). The h-RPL model progressively learns latent variable representations comparable to those of end-to-end trained RPL.
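The structure of panel (a) can be sketched as a forward pass that stacks circuits and collects one local prediction error per circuit. This is a minimal NumPy sketch under stated assumptions. The circuit internals (tanh encoder, tanh recurrent integrator, linear predictor) and the one-step-ahead target zₜ₊₁ are hypothetical choices, and `init_circuit`, `run_circuit`, and `run_hierarchy` are illustrative names. NumPy has no autodiff, so the stop-gradient is only marked in a comment. In an autodiff framework, one would wrap each circuit's input, for example with `jax.lax.stop_gradient` or PyTorch's `Tensor.detach()`, so that circuit k's PE updates only circuit k's weights.

```python
import numpy as np


def init_circuit(rng, d_in, d_z, d_c, scale=0.3):
    """Random weights for one hypothetical RPL circuit."""
    return {
        "W_e": rng.normal(scale=scale, size=(d_z, d_in)),  # shallow encoder
        "W_z": rng.normal(scale=scale, size=(d_c, d_z)),   # integrator input
        "W_c": rng.normal(scale=scale, size=(d_c, d_c)),   # integrator recurrence
        "W_p": rng.normal(scale=scale, size=(d_z, d_c)),   # predictor
    }


def run_circuit(p, u):
    """u: (T, d_in) inputs. Returns representations c (T, d_c) and the local PE."""
    z = np.tanh(u @ p["W_e"].T)                      # embeddings z_t
    h = np.zeros(p["W_c"].shape[0])
    c = np.zeros((len(u), len(h)))
    for t in range(len(u)):
        h = np.tanh(p["W_c"] @ h + p["W_z"] @ z[t])  # recurrent integration
        c[t] = h
    z_hat = c[:-1] @ p["W_p"].T                      # predict z_{t+1} from c_t
    pe = np.mean(np.sum((z_hat - z[1:]) ** 2, axis=1))
    return c, pe


def run_hierarchy(circuits, x):
    """Stack circuits; each one's representation is the next one's input."""
    u, pes = x, []
    for p in circuits:
        c, pe = run_circuit(p, u)
        pes.append(pe)
        u = c  # with autodiff, apply stop-gradient here (the SG in panel a)
    return u, pes


rng = np.random.default_rng(0)
x = rng.normal(size=(20, 16))  # stand-in input sequence: 20 frames, 16 features
circuits = [init_circuit(rng, 16 if k == 0 else 8, 8, 8) for k in range(6)]
c_top, pes = run_hierarchy(circuits, x)
```

The total training loss would be the sum of the six local PEs. Because of the stop-gradient, the gradient of that sum with respect to any circuit's weights reduces to the gradient of that circuit's own PE.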