Properties of neural trajectories in motor cortex, illustrated for data recorded during the cycling task (MC_Cycle dataset).

a) Low tangling implies separation between trajectories that might otherwise be close. Top. Neural trajectories for forward (purple) and backward (orange) cycling. Trajectories begin 600 ms before movement onset and end 600 ms after movement offset. Trajectories are based on trial-averaged firing rates, projected onto two dimensions. Dimensions were hand-selected to highlight apparent trajectory crossings while also capturing considerable variance (11.5%, comparable to the 11.6% captured by PCs 3 and 4). Gray region highlights one set of apparent crossings. Bottom. Trajectories during the restricted range of times in the gray region, but projected onto different dimensions: the top two principal components when PCA was applied to this subset of data. The same scale is used in top and bottom subpanels. b) Examples of similar behavioral states (inset) corresponding to well-separated neural states (main panel). Colored trajectory tails indicate the previous 150 ms of states. Data from 7-cycle forward condition. c) Joint distribution of pairwise distances for muscle and neural trajectories. Analysis considered all states across all conditions. For both muscle and neural trajectories, we computed all possible pairwise distances across all states. Each muscle state has a corresponding neural state, from the same time within the same condition. Thus, each pairwise muscle-state distance has a corresponding neural-state distance. The color of each pixel indicates how common it was to observe a particular combination of muscle-state distance and neural-state distance. Muscle trajectories are based on seven z-scored intramuscular EMG recordings. Correspondence between neural and muscle state pairs included a 50 ms lag to account for physiological latency. Results were not sensitive to the presence or size of this lag. Neural and muscle distances were normalized (separately) by average pairwise distance. 
d) Same analysis for neural and kinematic distances (based on phase and angular velocity). Correspondence between neural and kinematic state pairs included a 150 ms lag. Results were not sensitive to the presence or size of this lag. e) Control analysis to assess the impact of sampling error. If two sets of trajectories (e.g. neural and kinematic) are isometric and can be estimated perfectly, their joint distribution should fall along the diagonal. To estimate the impact of sampling error, we repeated the above analysis comparing neural distances across two data partitions, each containing 15-18 trials/condition.
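The joint-distribution analysis in panels c-e can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' code: the function names, the synthetic trajectories, the sample spacing, and the direction of the lag (neural state at time t paired with the muscle state `lag_bins` samples later) are all hypothetical, and a single condition is used for simplicity.

```python
import numpy as np

def normalized_pairwise_distances(states):
    """All pairwise Euclidean distances between rows, divided by their mean."""
    diffs = states[:, None, :] - states[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=-1))
    upper = np.triu_indices(len(states), k=1)  # each unordered pair once
    d = dists[upper]
    return d / d.mean()

def joint_distance_histogram(neural, muscle, lag_bins=0, n_bins=50):
    """2D histogram over (muscle distance, neural distance) for all state pairs.

    neural: (T, n_neurons) and muscle: (T, n_muscles), sampled at the same times.
    Assumption: the neural state at sample t is paired with the muscle state
    at sample t + lag_bins.
    """
    if lag_bins > 0:
        neural, muscle = neural[:-lag_bins], muscle[lag_bins:]
    d_neural = normalized_pairwise_distances(neural)
    d_muscle = normalized_pairwise_distances(muscle)
    hist, muscle_edges, neural_edges = np.histogram2d(d_muscle, d_neural, bins=n_bins)
    return hist / hist.sum(), muscle_edges, neural_edges

# Synthetic stand-ins for real trajectories (hypothetical data).
rng = np.random.default_rng(0)
t = np.linspace(0, 4 * np.pi, 200)
muscle_demo = np.column_stack([np.sin(t), np.cos(t)])
neural_demo = np.column_stack([np.sin(t), np.cos(t), t / t.max()])
neural_demo = neural_demo + 0.01 * rng.standard_normal(neural_demo.shape)
hist_demo, _, _ = joint_distance_histogram(neural_demo, muscle_demo, lag_bins=5)
```

In this toy case the muscle trajectory repeats every cycle while the neural trajectory drifts along a third dimension, so some pairs with near-zero muscle distance have clearly non-zero neural distance, which is the kind of off-diagonal mass the panels visualize.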

Example training (top panel) and decoding (bottom panels) procedures for MINT, illustrated using four conditions from a reaching task.

a) Libraries of neural and behavioral trajectories are learned such that each neural state corresponds to a behavioral state. b) Spiking observations are binned spike counts (20 ms bins for all main analyses). The observation at each time step contains the spike count of each neuron for the present bin, t′, and for τ′ bins in the past. c) At each time step during decoding, the log-likelihood of observing these spike counts is computed for every state in the library of neural trajectories. Log-likelihoods decompose across neurons and time bins into Poisson log-likelihoods that can be queried from a precomputed lookup table. A recursive procedure (not depicted) further improves computational efficiency. Despite utilizing binned spiking observations, log-likelihoods can be updated at millisecond resolution (Methods). d) Two candidate neural states (purple and blue) are identified. The first is the state within the library of trajectories that maximizes the log-likelihood of the observed spikes. The second state similarly maximizes that log-likelihood, with the restriction that the second state must not come from the same trajectory as the first (i.e. must be from a different condition). Interpolation identifies an intermediate state that maximizes log-likelihood. e) The optimal interpolation is applied to candidate neural states – yielding the final neural-state estimate – and their corresponding behavioral states – yielding decoded behavior.
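The decoding steps in panels c-e can be sketched in a few lines. This is a simplified illustration, not the MINT implementation: trajectories are sampled once per bin rather than every millisecond, the lookup table and recursive update are omitted, and the interpolation weight is found by a plain grid search (a stand-in for whatever optimizer the Methods describe). All function and variable names are hypothetical.

```python
import numpy as np

def window_loglik(traj, counts):
    """Poisson log-likelihood of `counts` for each state of one trajectory.

    traj: (K, N) expected spike counts per bin; counts: (W, N), oldest bin first.
    The log(count!) term is dropped because it is identical for every state.
    States without W bins of history get -inf.
    """
    K = traj.shape[0]
    W = counts.shape[0]
    log_traj = np.log(traj)  # assumes strictly positive rates
    ll = np.full(K, -np.inf)
    for k in range(W - 1, K):
        seg = slice(k - W + 1, k + 1)
        ll[k] = np.sum(counts * log_traj[seg] - traj[seg])
    return ll

def decode(library, behaviors, counts, n_alpha=21):
    """Pick the best state on two different trajectories, then interpolate."""
    W = counts.shape[0]
    lls = [window_loglik(traj, counts) for traj in library]
    best_k = [int(np.argmax(ll)) for ll in lls]
    best_ll = np.array([ll[k] for ll, k in zip(lls, best_k)])
    c1, c2 = np.argsort(best_ll)[::-1][:2]  # best and second-best trajectories

    def window(c):
        k = best_k[c]
        return library[c][k - W + 1:k + 1]

    w1, w2 = window(c1), window(c2)
    alphas = np.linspace(0.0, 1.0, n_alpha)  # grid search over the mixing weight
    interp_ll = []
    for a in alphas:
        mix = (1 - a) * w1 + a * w2
        interp_ll.append(np.sum(counts * np.log(mix) - mix))
    a = float(alphas[int(np.argmax(interp_ll))])
    neural = (1 - a) * library[c1][best_k[c1]] + a * library[c2][best_k[c2]]
    behavior = (1 - a) * behaviors[c1][best_k[c1]] + a * behaviors[c2][best_k[c2]]
    return neural, behavior, a

# Hypothetical two-condition library: 30 states, 5 neurons, 2 behavioral variables.
rng = np.random.default_rng(0)
library = [rng.uniform(0.5, 3.0, size=(30, 5)) for _ in range(2)]
behaviors = [rng.standard_normal((30, 2)) for _ in range(2)]
W = 15  # present bin plus 14 past bins
counts = rng.poisson(library[0][20 - W + 1:21])
neural_hat, behavior_hat, alpha = decode(library, behaviors, counts)
```

Because the same weight is applied to the neural and behavioral states, the decoded behavior follows directly from the neural-state estimate with no separate readout model.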

Examples of behavioral decoding provided by MINT for one simulated dataset and four empirically recorded datasets.

All decoding is causal; only spikes from before the decoded moment are used. a) MINT was applied to spiking data from an artificial spiking network. That network was trained to generate posterior deltoid activity and to switch between reaching and cycling tasks. Based on spiking observations, MINT approximately decoded the true network output at each moment. ‘R3’, ‘R4’, and ‘R6’ indicate three different reach conditions. ‘C’ indicates a cycling bout. MINT used no explicit task-switching, but simply tracked neural trajectories across tasks as if they were conditions within a task. b) Illustration of the challenging nature, from a decoding perspective, of network trajectories. Trajectories are shown for two dimensions that are strongly occupied during cycling. Trajectories for the 8 reaching conditions (pink) are all nearly orthogonal to the trajectory for cycling (brown) and thus appear compressed in this projection. c) Decoded behavioral variables (green) compared to actual behavioral variables (black) across four empirical datasets. MC_Cycle and MC_RTT show 10 seconds of continuous decoding. MC_Maze and Area2_Bump show randomly selected trials, demarcated by vertical dashed lines.

Comparison of decoding performance, for MINT and four additional algorithms, for four datasets and multiple decoded variables.

On a given trial, MINT decodes all behavioral variables in a unified way based on the same inferred neural state. For non-MINT algorithms, separate decoders were trained for each behavioral group. E.g. separate GRUs were trained to output ‘position’ and ‘velocity’ in MC_Maze. Parentheticals indicate the number of behavioral variables within a group. E.g. ‘position (2)’ has two components: x- and y-position. R2 is averaged across behavioral variables within a group. ‘Overall’ plots performance averaged across all behavioral groups. R2 for feedforward networks and GRUs is additionally averaged across runs for 10 random seeds. The Kalman filter is traditionally utilized for position- and velocity-based decoding and was therefore only used to predict these behavioral groups. Accordingly, the ‘overall’ category excludes the Kalman filter for datasets in which the Kalman filter did not contribute predictions for every behavioral group. Vertical offsets and vertical ticks are used to increase visibility of data.

Evaluation of neural state estimates for seven datasets.

a) Performance quantified using bits per spike. The benchmark’s baseline methods have colored markers. All other submissions have gray markers. Vertical offsets and vertical ticks are used to increase visibility of data. Results with negative values are designated by markers to the left of zero, but their locations don’t reflect the magnitude of the negative values. Data from Neural Latents Benchmark (https://neurallatents.github.io/). b) Performance quantified using PSTH R2. c) Performance quantified using velocity R2, after velocity was decoded linearly from the neural state estimate. A linear decode is used simply as a way of evaluating the quality of neural state estimates, especially in dimensions relevant to behavior.
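For readers unfamiliar with the bits-per-spike metric in panel a, the sketch below shows one common formulation: the Poisson log-likelihood of the predicted rates minus that of a per-neuron mean-rate null model, divided by the spike count and by ln 2. It is assumed here, not confirmed by this caption, that the benchmark's metric takes this form. The function name and synthetic data are hypothetical.

```python
import numpy as np

def bits_per_spike(pred_rates, counts, eps=1e-9):
    """Log-likelihood gain over a mean-rate null model, in bits per spike.

    pred_rates, counts: (n_samples, n_neurons); rates in expected spikes per bin.
    The log(count!) term cancels in the difference, so it is omitted.
    """
    null_rates = np.broadcast_to(counts.mean(axis=0, keepdims=True), counts.shape)

    def poisson_ll(rates):
        rates = np.clip(rates, eps, None)  # guard against log(0)
        return np.sum(counts * np.log(rates) - rates)

    return (poisson_ll(pred_rates) - poisson_ll(null_rates)) / (counts.sum() * np.log(2))

# Synthetic check: time-varying true rates should beat the mean-rate null model.
rng = np.random.default_rng(0)
phase = np.linspace(0, 20 * np.pi, 2000)[:, None] + np.arange(10)[None, :]
true_rates = 0.2 + 2.0 * np.sin(phase) ** 2
spikes = rng.poisson(true_rates)
bps_true = bits_per_spike(true_rates, spikes)
bps_null = bits_per_spike(np.tile(spikes.mean(axis=0), (2000, 1)), spikes)
```

A model that predicts only each neuron's mean rate scores zero by construction, and predictions worse than that null model yield negative values, which is why some markers in panel a sit to the left of zero.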

MINT training times and average execution times (average time it took to decode a 20 ms bin of spiking observations).

To be appropriate for real-time applications, execution times must be shorter than the bin width. Note that the 20 ms bin width does not prevent MINT from decoding every millisecond; MINT updates the inferred neural state and associated behavioral state between bins, using neural and behavioral trajectories that are sampled every millisecond. For all table entries (except MC_RTT training time), means and standard deviations were computed across 10 train/test runs for each dataset. Training times exclude loading datasets into memory and any hyperparameter optimization. Timing measurements were taken on a MacBook Pro (on CPU) with 32 GB RAM and a 2.3 GHz 8-core Intel Core i9 processor. For MC_RTT, training involved running AutoLFADS twice (and averaging the resulting rates) to generate neural trajectories. This training procedure utilized 10 GPUs and took 1.6 hours per run. For AutoLFADS, hyperparameter optimization and model fitting procedures are intertwined. Thus, the training time reported includes time spent optimizing hyperparameters.

Decoding robustness in the face of small training sets and neuron counts.

a) R2 values for MINT and two neural network decoders (GRU and the feedforward network) when decoding position and velocity from four maze datasets with progressively fewer training trials per condition. b) R2 values for MINT and the Wiener filter when decoding position and velocity from MC_Maze-L with progressively fewer neurons. ‘Retrained’ results (solid lines) show mean performance (across 50 random drops at each neuron count) when the decoders are trained and tested on reduced neuron sets. ‘Zeroed’ results (dashed lines) show mean performance when the decoders are trained on all 162 neurons and, during testing, the ‘dropped’ neurons’ spike counts are set to zero without retraining the decoder. Shaded regions depict standard errors across repeated droppings with different neurons.
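The difference between the ‘retrained’ and ‘zeroed’ conditions in panel b can be made concrete with a small sketch. A ridge regression on current-bin spike counts is used here as a simplified stand-in for the decoders in the figure (it has no spiking history and, for brevity, also regularizes the bias term). All names and the synthetic data are hypothetical.

```python
import numpy as np

def fit_ridge(X, Y, lam=1.0):
    """Ridge regression with a bias column (bias is regularized too, for brevity)."""
    Xb = np.column_stack([X, np.ones(len(X))])
    A = Xb.T @ Xb + lam * np.eye(Xb.shape[1])
    return np.linalg.solve(A, Xb.T @ Y)

def predict(W, X):
    return np.column_stack([X, np.ones(len(X))]) @ W

def r2(y_true, y_pred):
    """Coefficient of determination, averaged across behavioral variables."""
    ss_res = ((y_true - y_pred) ** 2).sum(axis=0)
    ss_tot = ((y_true - y_true.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1 - ss_res / ss_tot))

def drop_comparison(X_train, Y_train, X_test, Y_test, keep, lam=1.0):
    """Return (retrained R2, zeroed R2) for the neurons listed in `keep`."""
    # Retrained: train and test only on the kept neurons.
    W_retrained = fit_ridge(X_train[:, keep], Y_train, lam)
    r2_retrained = r2(Y_test, predict(W_retrained, X_test[:, keep]))
    # Zeroed: train on all neurons, then zero the dropped neurons at test time.
    W_full = fit_ridge(X_train, Y_train, lam)
    X_zeroed = np.zeros_like(X_test)
    X_zeroed[:, keep] = X_test[:, keep]
    r2_zeroed = r2(Y_test, predict(W_full, X_zeroed))
    return r2_retrained, r2_zeroed

# Synthetic data: 20 neurons whose counts linearly drive 2 behavioral variables.
rng = np.random.default_rng(0)
X = rng.poisson(3.0, size=(1000, 20))
Y = X @ rng.standard_normal((20, 2)) + 0.5 * rng.standard_normal((1000, 2))
keep_demo = rng.choice(20, size=10, replace=False)
r2_retrained_demo, r2_zeroed_demo = drop_comparison(X[:800], Y[:800], X[800:], Y[800:], keep_demo)
```

Repeating `drop_comparison` over many random choices of `keep` and averaging would mirror the repeated-drop procedure the caption describes.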

Neuron, condition, and trial counts for each dataset. For some datasets, there are additional trials that are excluded from these counts. These trials are excluded because they were only usable for Neural Latents Benchmark submissions due to hidden behavioral data and partially hidden spiking data (see Neural Latents Benchmark).

Hyperparameters for learning neural trajectories via standard trial-averaging. σ is the standard deviation of the Gaussian kernel used to temporally filter spikes. Trial-averaging Type I and Type II procedures are described in Averaging across trials. Dneural and Dcondition are the neural- and condition-dimensionalities described in Smoothing across neurons and/or conditions. ‘Full’ means that no dimensionality reduction was performed. Condition smoothing could not be performed for MC_Cycle or the multitask network because different conditions in these datasets are of different lengths (i.e. Kc is not the same for all c). (C) and (R) refer to the cycling and reaching trajectories, respectively, in the multitask network. tmove, tstop, and tgo correspond to movement onset, movement offset, and the ‘go’ time in the ready-set-go task, respectively.

Details for decoding analyses. The number of training and testing trials for each dataset are provided along with the evaluation period over which performance was computed. The window lengths refer to the amount of spiking history MINT used for decoding (e.g. when τ′ = 14 and Δ = 20, the window length is (τ′+ 1)Δ = 300 ms). tmove corresponds to movement onset, tstop corresponds to movement offset, tstart refers to the beginning of a trial, and tend refers to the end of a trial. There is no defined condition structure for MC_RTT to use for defining trial boundaries. Thus, each trial is simply a 600 ms segment of data, with no alignment to movement. Although 270 of these segments were available for testing, the first 2 segments lacked sufficient spiking history for all decoders to be evaluated and were therefore excluded, leaving 268 test trials. For the multitask network, performance was evaluated on a continuous stretch of 135 trials spanning 7.1 minutes.

Details for neural state estimation results. Note that the training trial counts match the total number of trials reported in Table 2. This reflects that the Neural Latents Benchmark utilized an additional set of test trials not reflected in the Table 2 trial counts. The test trials used for this analysis have ground truth behavior hidden by the benchmark creators and are therefore only suitable for this analysis.

MINT’s decoding performance is robust to the choice of hyperparameters.

MINT was run on the MC_Maze dataset with systematic perturbations to two of MINT’s hyperparameters: bin width (ms) and window length (ms). Bin width is the size of the bin in which spikes are counted: 20 ms for all analyses in the main figures. Window length is the length, in milliseconds, of spiking history that is considered. For all main analyses of the MC_Maze dataset, this was 300 ms, i.e. MINT considered the spike count in the present bin and in 14 previous bins. Perturbations were also made to two hyperparameters related to learning neural trajectories: temporal smoothing width (standard deviation of Gaussian filter applied to spikes) and condition-smoothing dimensionality (see Methods). These two hyperparameters describe how aggressively the trial-averaged data are smoothed (across time and conditions, respectively) when learning rates. Baseline decoding performance (black circles) was computed using the same hyperparameters that were used with the MC_Maze dataset in the analyses from Figure 3 and Figure 4. Then, decoding performance was computed by perturbing each of the four hyperparameters twice (colored circles): once to 50% of the hyperparameter’s baseline value and once to 150%. Trials were bootstrapped (1000 resamples) to generate 95% confidence intervals (error bars). Perturbations of hyperparameters had little impact on performance. Altering bin width had essentially no impact, nor did altering temporal smoothing. Shortening window length had a negative impact, presumably because MINT had to estimate the neural state using fewer observations. However, the drop in performance was minimal: R2 dropped by .011 for position decoding only. Reducing the number of dimensions used for across-condition smoothing, and consequently over-smoothing the data, had a negative impact on both position and velocity decoding. Yet again this was small: e.g. velocity R2 dropped by .010. 
These results demonstrate that MINT can achieve high performance using hyperparameter values that span a large range. Thus, they do not need to be meticulously optimized to ensure good performance. In general, optimization may not be needed at all, as MINT’s hyperparameters can often be set based on first principles. For example, in this study, bin width was never optimized either for specific datasets or in general. We chose to always count spikes in 20 ms bins (except in the perturbations shown here) because this duration is long enough to reduce computation time yet short relative to the timescales over which rates change. Additionally, window length can be optimized (as we did for decoding analyses), but it could also simply be chosen to roughly match the timescale over which past behavior tends to predict future behavior. Temporal smoothing of trajectories when building the library can simply use the same values commonly used when analyzing such data. For example, in prior studies, we have used smoothing kernels of width 20 to 30 ms when computing trial-averaged rates, and these values also support excellent decoding. Condition smoothing is optional and need not be applied at all, but may be useful if one wishes to record fewer trials for more conditions. For example, rather than record 15 trials for 8 reach directions, one might wish to record 5 trials for 24 conditions, then use condition smoothing to reduce sampling error.
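The trial-bootstrap used for the error bars can be sketched as follows. This is an illustrative percentile bootstrap under assumptions: the function names and synthetic data are hypothetical, and the paper's exact resampling details may differ.

```python
import numpy as np

def r2_score(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def bootstrap_r2_ci(y_true_trials, y_pred_trials, n_resamples=1000, ci=95, seed=0):
    """Percentile confidence interval for R2, resampling whole trials with replacement.

    y_true_trials, y_pred_trials: arrays of shape (n_trials, n_timepoints).
    """
    rng = np.random.default_rng(seed)
    n_trials = len(y_true_trials)
    stats = np.empty(n_resamples)
    for i in range(n_resamples):
        idx = rng.integers(0, n_trials, size=n_trials)
        stats[i] = r2_score(y_true_trials[idx].ravel(), y_pred_trials[idx].ravel())
    tail = (100 - ci) / 2
    lo, hi = np.percentile(stats, [tail, 100 - tail])
    return float(lo), float(hi)

# Perturbation grid for the two hyperparameters whose baselines are stated above.
baseline = {"bin_width_ms": 20, "window_length_ms": 300}
perturbed = {name: (0.5 * value, 1.5 * value) for name, value in baseline.items()}

# Synthetic decoded vs. actual behavior (hypothetical data).
rng = np.random.default_rng(1)
y_true_demo = rng.standard_normal((40, 50))
y_pred_demo = y_true_demo + 0.3 * rng.standard_normal((40, 50))
ci_lo, ci_hi = bootstrap_r2_ci(y_true_demo, y_pred_demo)
```

Resampling whole trials, rather than individual time points, keeps the within-trial temporal correlations intact in each resample.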

Illustration of why decoding will typically fail to generalize across tasks when neural trajectories occupy orthogonal subspaces.

In theory, learning the output-potent dimensions in motor cortex would be an effective strategy for biomimetic decoding that generalizes to novel tasks. In practice, it is difficult (and often impossible) to statistically infer these dimensions without observing the subject’s full behavioral repertoire (at which point generalization is no longer needed because all the relevant behaviors were directly observed). a) In this toy example, two neural trajectories occupy fully orthogonal subspaces. In the ‘blue task’, the neural trajectory occupies dimensions 1 and 2. In the ‘red task’, the neural trajectory occupies dimensions 3 and 4. Trajectories in both subspaces have a non-zero projection onto wout, enabling neural activity from each task to drive the appropriate output. The output at right is simply the projection of the blue-task and red-task trajectories onto wout. b) Illustration of the difficulty of inferring wout from only one task. In this example, wout is learned using data from the blue task, by linearly regressing the observed output against the neural trajectory for the blue task. The resulting estimate, ŵ1, correctly translates neural activity in the blue task into the appropriate time-varying output, but fails to generalize to the red task. This failure to generalize is a straightforward consequence of the fact that ŵ1 was learned from data that didn’t explore dimensions 3 and 4. Note that this same phenomenon would occur if ŵ1 were estimated by regressing intended output versus neural activity (as might occur in a paralyzed patient). c) Estimating the output-potent dimension based only on the red task yields an estimate, ŵ2, that fails to generalize to the blue task. This phenomenon is illustrated here for a linear readout, but would apply to most nonlinear methods as well, unless some other form of knowledge can allow interpretation of neural trajectories in previously unseen dimensions.
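The toy example can be reproduced numerically. The sketch below assumes specific trajectories and a specific wout (all values are hypothetical, chosen only for illustration) and uses minimum-norm least squares, which assigns zero weight to dimensions the training data never explored.

```python
import numpy as np

t = np.linspace(0, 2 * np.pi, 100)
zeros = np.zeros_like(t)

# True readout has non-zero weight in all four dimensions (assumed values).
w_out = np.array([1.0, 0.5, -0.5, 1.0])

# Blue task occupies dimensions 1-2; red task occupies dimensions 3-4.
blue = np.column_stack([np.cos(t), np.sin(t), zeros, zeros])
red = np.column_stack([zeros, zeros, np.cos(2 * t), np.sin(2 * t)])
out_blue = blue @ w_out
out_red = red @ w_out

# Regress output on neural activity using only the blue task.
# np.linalg.lstsq returns the minimum-norm solution for rank-deficient problems.
w_hat_1, *_ = np.linalg.lstsq(blue, out_blue, rcond=None)

blue_error = np.max(np.abs(blue @ w_hat_1 - out_blue))  # near zero: fits the blue task
red_prediction = red @ w_hat_1                          # near zero everywhere: no generalization
```

The estimate recovers the first two components of wout and leaves the last two at zero, so the red-task output is predicted to be flat even though the true red-task output is not.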

Impact of different modeling and preprocessing choices on performance of MINT.

For most applications we anticipate MINT will employ direct decoding that leverages the correspondence between neural and behavioral trajectories, but one could also choose to linearly decode behavior from the estimated neural state. For most applications, we anticipate MINT will use interpolation amongst candidate neural states on neighboring trajectories, but one could also restrict decoding to states within the trajectory library. For real-time applications we anticipate MINT will be run causally, but acausal decoding (using both past and future spiking observations) could be used offline or even online by introducing a lag. We anticipate MINT may be used both in situations where spike events have been sorted based on neuron identity, and situations where decoding simply uses channel-specific unsorted threshold crossings. Panels a-c explore the first three choices. Performance was quantified for all 8 combinations of: direct MINT readout vs. linear MINT readout, interpolation vs. no interpolation, causal decoding vs. acausal decoding. This was done for 121 behavioral variables across 4 datasets for a total of 964 R2 values. The ‘phase’ behavioral variable in MC_Cycle was excluded from ‘linear MINT readout’ variants because its circularity makes it a trivially poor target for linear decoding. a) MINT’s direct neural-to-behavioral-state association outperforms a linear readout based on MINT’s estimated neural state. Performance was significantly higher using the direct readout (ΔR2 = .061 ± .002 SE; p<.001, paired t-test). Note that the linear readout still benefits from the ability of MINT to estimate the neural state using all neural dimensions, not just those that correlate with kinematics. b) Decoding with interpolation significantly outperformed decoding without interpolation (ΔR2 = .018 ± .001 SE; p<.001, paired t-test). c) Running acausally significantly improved performance relative to causal decoding (ΔR2 = .051 ± .002 SE; p<.001, paired t-test). 
Although causal decoding is required for real-time applications, this result suggests that (when tolerable) introducing a small decoding lag could improve performance. For example, a decoder using 200 ms of spiking activity could introduce a 50 ms lag such that the decode at time t is rendered by generating the best estimate of the behavioral state at time t − 50 using spiking data from t − 200 through t. d) Decoding performance for 13 behavioral variables in the MC_Cycle dataset when sorted spikes were used (112 neurons, pink) versus ‘good’ threshold crossings from electrodes for which the signal-to-noise ratio of the firing rates exceeded a threshold (93 electrodes, SNR > 2, cyan). The loss in performance when using threshold crossings was small (ΔR2 = -.014 ± .002 SE). SE refers to standard error of the mean.
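One way to reconcile the stated total of 964 R2 values with 8 variants and 121 behavioral variables is to assume that the 121 variables include the MC_Cycle ‘phase’ variable, which is dropped from the four variants that use the linear readout (2 interpolation settings × 2 causality settings). This accounting is an inference from the numbers in the caption, not something the caption states:

```latex
\[
\underbrace{8 \times 121}_{\text{all variants} \,\times\, \text{all variables}}
\;-\;
\underbrace{4 \times 1}_{\text{linear-readout variants} \,\times\, \text{phase}}
\;=\; 968 - 4 \;=\; 964 .
\]
```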

The Kalman filter’s relative performance would improve if neural data had different statistical properties.

We compared MINT to the Kalman filter across one empirical dataset (MC_Maze) and four simulated datasets (same behavior as MC_Maze, but simulated spikes). Simulated firing rates were linear functions of hand position, velocity, and acceleration (as is assumed by the Kalman filter). The means and standard deviations (across time and conditions) of the simulated firing rates were matched to actual neural data (with additional rate scaling in two cases). The Kalman filter assumes that observation noise is stationary and Gaussian. Although spiking variability cannot be Gaussian (spike counts must be non-negative integers), spiking variability can be made more stationary by letting that variability depend less on rate. Thus, although the first simulation generated spikes via a Poisson process, subsequent simulations utilized gamma-interval spiking (gamma distribution with α = 2 and β = 2λ, where λ is firing rate). Gamma-interval spiking variability of this form is closer to stationary at higher rates. Thus, the third and fourth simulations scaled up firing rates to further push spiking variability into a more stationary regime at the expense of highly unrealistic rates (in the fourth simulation, a firing rate briefly exceeded 1800 Hz). Overall, as the simulated neural data better accorded with the assumptions of the Kalman filter, decoding performance for the Kalman filter improved. Interestingly, MINT continued to perform well even on the simulated data, likely because MINT can exploit a linear relationship between neural and behavioral states when the data argue for it and higher rates benefit both algorithms. These results demonstrate that algorithms like MINT and the Kalman filter are not intrinsically good or bad. Rather, they are best suited to data that match their assumptions. When simulated data approximate the assumptions of the Kalman filter, both methods perform similarly. 
However, MINT shows much higher performance for the empirical data, suggesting that its assumptions are a better match for the statistical properties of the data.
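The difference between Poisson and gamma-interval spiking can be illustrated with a short simulation. This sketch assumes a constant firing rate and a simple renewal process (the simulations described above used time-varying rates). The function name and parameter values are hypothetical.

```python
import numpy as np

def simulate_counts(rate_hz, duration_s, bin_s, alpha, rng):
    """Binned spike counts from a renewal process with gamma-distributed intervals.

    Intervals follow a gamma distribution with shape alpha and rate beta = alpha * rate_hz,
    so the mean interval is 1 / rate_hz. alpha = 1 gives a Poisson process;
    alpha = 2 gives beta = 2 * rate_hz, as in the caption.
    """
    n_intervals = int(1.5 * rate_hz * duration_s) + 100  # comfortably more than needed
    intervals = rng.gamma(alpha, 1.0 / (alpha * rate_hz), size=n_intervals)
    spike_times = np.cumsum(intervals)
    spike_times = spike_times[spike_times < duration_s]
    n_bins = int(round(duration_s / bin_s))
    edges = np.linspace(0.0, duration_s, n_bins + 1)
    counts, _ = np.histogram(spike_times, bins=edges)
    return counts

def fano_factor(counts):
    return counts.var() / counts.mean()

rng = np.random.default_rng(0)
poisson_counts = simulate_counts(50.0, 200.0, 0.02, alpha=1, rng=rng)
gamma_counts = simulate_counts(50.0, 200.0, 0.02, alpha=2, rng=rng)
fano_poisson = fano_factor(poisson_counts)
fano_gamma = fano_factor(gamma_counts)
```

With the same mean rate, the gamma-interval counts are more regular than the Poisson counts (Fano factor below 1), which is the sense in which their variability depends less strongly on rate.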

MINT is a modular algorithm for which a variety of modifications and extensions exist.

The flowchart illustrates the standard MINT algorithm in black and lists potential changes to the algorithm in red. For example, the library of neural trajectories is typically learned via standard trial-averaging or a single-trial rate estimation technique like AutoLFADS. However, transfer learning could be utilized to learn the library of trajectories based on trajectories from a different subject or session. The trajectories could also be modified online while MINT is running to reflect changes in spiking activity that relate to recording instabilities. A potential modification to the method also occurs when likelihoods are converted into posterior probabilities. Typically, we assume a uniform prior over states in the library. However, that prior could be set to reflect the relative frequency of different behaviors and could even incorporate time-varying information from external sensors (e.g. state of a prosthetic limb, eye tracking, etc.) that include information about how probable each behavior is at a given moment. Another potential extension of MINT occurs at the stage where candidate neural states are selected. Typically, these states are selected based solely on spike count likelihoods. However, one could use a utility function that reflects a user’s values (e.g. the user may dislike some decoding mistakes more than others) in conjunction with the likelihoods to maximize expected utility. Lastly, the behavioral estimate that MINT returns could be post-processed (e.g. temporally smoothed). MINT’s modularity largely derives from the fact that the library of neural trajectories is finite. This assumption enables posterior probabilities to be directly computed for each state, rather than analytically derived. Thus, choices like how to learn the library of trajectories, which observation model to use (e.g. Poisson vs. generalized Poisson), and which (if any) state priors to use, can be made independently from one another. 
These choices all interact to impact performance. Yet they will not interact to impact tractability, as would have been the case if one analytically derived a continuous posterior probability distribution.
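Because the library is finite, converting likelihoods into posterior probabilities amounts to a normalized sum over states, and a non-uniform prior is a one-line change. The sketch below illustrates this with hypothetical names and toy numbers. It is not taken from the MINT codebase.

```python
import numpy as np

def posterior_over_states(log_likelihoods, prior=None):
    """Posterior probability of each library state given its log-likelihood.

    prior: optional array of prior probabilities (same length); uniform if None.
    """
    log_lik = np.asarray(log_likelihoods, dtype=float)
    if prior is None:
        log_prior = np.zeros_like(log_lik)  # uniform prior adds only a constant
    else:
        with np.errstate(divide="ignore"):  # allow prior entries equal to zero
            log_prior = np.log(np.asarray(prior, dtype=float))
    log_post = log_lik + log_prior
    log_post -= log_post.max()  # subtract the max for numerical stability
    post = np.exp(log_post)
    return post / post.sum()

# Toy example with three states (hypothetical numbers).
log_lik_demo = [-10.0, -11.0, -12.0]
post_uniform = posterior_over_states(log_lik_demo)
post_informed = posterior_over_states(log_lik_demo, prior=[0.1, 0.1, 0.8])
```

In this toy case the informed prior moves the most probable state from the first entry to the third, without any change to how the likelihoods themselves are computed, which is the kind of independence between modeling choices the caption describes.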

Video demonstrating causal neural state estimation and behavioral decoding from MINT on the MC_Cycle dataset (see supplementary file).

In this dataset, a monkey moved a hand pedal cyclically forward or backward in response to visual cues. The raster plot of spiking activity (112 neurons, bottom right subpanel) and the actual and decoded angular velocities of the pedal (top right subpanel) are animated with 10 seconds of trailing history. Decoding was causal; the decode of the present angular velocity (right hand edge of scrolling traces) was based only on present and past spiking. Cycling speed was 2 Hz. The underlying neural state estimate (green sphere in left subpanel) is plotted in a 3D neural subspace, with a 2D projection below. The present neural state is superimposed on top of the library of 8 neural trajectories used by MINT. The state estimate always remained on or near (via state interpolation) the neural trajectories. Purple and orange trajectories correspond to forward and backward pedaling conditions, respectively. The lighter-to-darker color gradients differentiate between trajectories corresponding to 1-, 2-, 4-, and 7-cycle conditions for each pedaling direction. Neural state estimate corresponds in time to the right edge of the scrolling raster/velocity plot. The three-dimensional neural subspace was hand-selected to capture a large amount of neural variance (59.6%; close to the 62.9% captured by the top 3 PCs) while highlighting the dominant translational and rotational structure in the trajectories.

Hyperparameters used for the Wiener filter. The L2 regularization term λ was optimized in the range [0, 2000], with the optimized values rounded to the closest multiple of 10. Window lengths were optimized (in 20 ms increments) in the range [200, 600] for Area2_Bump, [200, 1000] for MC_Cycle, [200, 1200] for MC_RTT, and [200, 700] for MC_Maze, MC_Maze-L, MC_Maze-M, and MC_Maze-S. These ranges were determined by the structure of each dataset (e.g. Area2_Bump couldn’t look back more than 600 ms from the beginning of the evaluation epoch without entering the previous trial). Window lengths are directly related to κ via Δ (e.g. κ = 14 would correspond to a window length of Δ(κ + 1) = 20(14 + 1) = 300 ms).

Hyperparameters used for the Kalman filter. The lag (in increments of 20 ms time bins) between neural activity and behavior was optimized in the range [2, 8], corresponding to 40-160 ms, for all datasets except Area2_Bump. For Area2_Bump the lag was not optimized and was simply set to 0 due to the fact that, in a sensory area, movement precedes sensory feedback. Given that xk aggregates spikes across the whole time bin, but yk corresponds to the behavioral variables at the end of the time bin, the effective lag is actually half a bin (10 ms) longer — i.e. the effective range of lags considered for the non-sensory datasets was 50-170 ms.
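The lag arithmetic above can be checked directly. The variable names in this short sketch are hypothetical.

```python
bin_ms = 20
lag_bins = range(2, 9)  # candidate lags of 2 through 8 bins

nominal_lags_ms = [k * bin_ms for k in lag_bins]
# Spikes are aggregated over the whole bin while behavior is taken at the bin's end,
# so the effective lag is half a bin longer than the nominal lag.
effective_lags_ms = [lag + bin_ms / 2 for lag in nominal_lags_ms]

nominal_range = (nominal_lags_ms[0], nominal_lags_ms[-1])
effective_range = (effective_lags_ms[0], effective_lags_ms[-1])
```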

Hyperparameters used for the feedforward neural network. The number of hidden layers (L) was optimized in the range [1, 15]. The number of units per hidden layer (D) was optimized in the range [50, 1000], with the optimized values rounded to the closest multiple of 10. The dropout rate was optimized in the range [0, 0.5] and the number of training epochs was optimized in the range [2, 100]. Window lengths were optimized (in 20 ms increments) in the range [200, 600] for Area2_Bump, [200, 1000] for MC_Cycle, [200, 1200] for MC_RTT, and [200, 700] for MC_Maze, MC_Maze-L, MC_Maze-M, and MC_Maze-S.

Hyperparameters used for the GRU network. The number of units (D) was optimized in the range [50, 1000], with the optimized values rounded to the closest multiple of 10. The dropout rate was optimized in the range [0, 0.5] and the number of training epochs was optimized in the range [2, 50]. Window lengths were optimized (in 20 ms increments) in the range [200, 600] for Area2_Bump, [200, 1000] for MC_Cycle, [200, 1200] for MC_RTT, and [200, 700] for MC_Maze, MC_Maze-L, MC_Maze-M, and MC_Maze-S.