Figures and data
Schematic of the overall framework. Given a task (e.g., an analogy to solve), inputs (denoted as {A, B, C, D}) are represented by grid codes, consisting of units (“grid cells”) representing different combinations of frequencies and phases. Grid embeddings (x_A, x_B, x_C, x_D) are multiplied elementwise by a set of learned attention weights w, then passed to an inference module R. The attention weights w are optimized using L_DPP, which encourages attention to grid embeddings that maximize the volume of the representational space. The inference module outputs a score for each candidate analogy (consisting of A, B, C and a candidate answer choice D). The scores for all answer choices are passed through a softmax to generate an answer ŷ, which is compared against the target y to generate the task loss L_task.
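The forward pass described in this caption can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: `toy_inference_module` is a hypothetical stand-in for R (the paper uses an LSTM or a transformer), the three-unit embeddings and the weights `w` are made-up values rather than real grid codes, and the task loss is assumed to be cross-entropy on the softmax output.

```python
import math

def gate(embedding, w):
    """Multiply a grid embedding elementwise by the attention weights w."""
    return [x * wi for x, wi in zip(embedding, w)]

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def toy_inference_module(a, b, c, d):
    """Hypothetical stand-in for R: higher when (d - c) matches (b - a)."""
    return -sum(((bi - ai) - (di - ci)) ** 2
                for ai, bi, ci, di in zip(a, b, c, d))

def forward(x_a, x_b, x_c, candidates, w, target_index):
    """Gate the embeddings, score each candidate D, and compute the task loss."""
    a, b, c = gate(x_a, w), gate(x_b, w), gate(x_c, w)
    scores = [toy_inference_module(a, b, c, gate(x_d, w)) for x_d in candidates]
    y_hat = softmax(scores)
    # Assumed task loss: cross-entropy against the index of the correct answer.
    task_loss = -math.log(y_hat[target_index])
    return y_hat, task_loss

# Made-up embeddings; w switches off the third unit entirely.
x_a, x_b, x_c = [0.0, 1.0, 0.5], [1.0, 1.0, 0.5], [0.0, 0.0, 0.2]
candidates = [[1.0, 0.0, 0.2], [0.0, 1.0, 0.9], [0.5, 0.5, 0.5]]
w = [1.0, 1.0, 0.0]
y_hat, task_loss = forward(x_a, x_b, x_c, candidates, w, target_index=0)
```

The sketch omits L_DPP, which is what shapes w during training; here w is simply fixed by hand.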
Generation of test analogies from training analogies (region marked in blue) by: a) translating both dimension values of A, B, C, D by the same amount; and b) scaling both dimension values of A, B, C, D by the same amount. Since both dimension values are transformed by the same amount, each input gets transformed along the diagonal.
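A minimal sketch of the two transformations follows. It assumes an analogy is a list of four 2D points A, B, C, D for which the offset from A to B equals the offset from C to D. The function names and the example coordinates are hypothetical, not taken from the paper.

```python
def translate(analogy, offset):
    """Shift both dimension values of every point by the same offset."""
    return [(x + offset, y + offset) for x, y in analogy]

def scale(analogy, factor):
    """Multiply both dimension values of every point by the same factor."""
    return [(x * factor, y * factor) for x, y in analogy]

def relation_holds(analogy):
    """Check the assumed analogy relation: B - A equals D - C."""
    a, b, c, d = analogy
    return (b[0] - a[0], b[1] - a[1]) == (d[0] - c[0], d[1] - c[1])

# Hypothetical training analogy: A, B, C, D with B - A == D - C == (2, 3).
train = [(1, 2), (3, 5), (4, 4), (6, 7)]
translated = translate(train, 100)
scaled = scale(train, 3)
```

Under this assumed relation, translation leaves the A-to-B offset unchanged, and scaling multiplies it by the same factor for both pairs. A valid training analogy therefore stays valid after either transformation.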
Training with DPP-A
Results on analogy on each region for translation and scaling, using an LSTM in the inference module.
Results on analogy on each region for translation and scaling, using a transformer in the inference module.
Results on arithmetic on each region, using an LSTM in the inference module.
Results on arithmetic on each region, using a transformer in the inference module.
Results on analogy on each region using DPP-A, an LSTM in the inference module, and different embeddings (grid codes, one-hots, and smoothed one-hots passed through a learned encoder) for translation (left) and scaling (right). Each point is mean accuracy over three networks, and bars show standard error of the mean.
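For readers who want to reproduce the summary statistics named in this caption, the mean and standard error of the mean over a handful of networks can be computed as below. The accuracy values are invented placeholders, not results from the paper. The sketch also assumes the sample standard deviation (n − 1 denominator), which the caption does not specify.

```python
import math
import statistics

def mean_and_sem(values):
    """Return the mean and the standard error of the mean of a sample."""
    mean = statistics.mean(values)
    # Assumption: sample standard deviation (n - 1 in the denominator).
    sem = statistics.stdev(values) / math.sqrt(len(values))
    return mean, sem

# Placeholder accuracies for three trained networks (illustrative only).
accuracies = [0.90, 0.94, 0.92]
mean_acc, sem_acc = mean_and_sem(accuracies)
```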
Results on analogy on each region using different embeddings (grid codes, and one-hots or smoothed one-hots with and without an encoder) and an LSTM in the inference module, but without DPP-A, TCN, L1 regularization, or dropout, for translation (left) and scaling (right).
Results on analogy on each region using LSTM in the inference module for choosing top K frequencies with
Results on analogy on each region for translation and scaling, using a transformer in the inference module.
Results on arithmetic with different embeddings (with DPP-A), using an LSTM in the inference module. Results show mean accuracy on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.
Results on arithmetic with different embeddings (without DPP-A, TCN, L1 regularization, or dropout), using an LSTM in the inference module. Results show mean accuracy on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.
Results on arithmetic on each region for an increasing number of grid cell frequencies N_f, using an LSTM in the inference module. Results show mean accuracy on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.
Results for regression on analogy, using an LSTM in the inference module. Results show mean squared error on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.
Results for regression on arithmetic on each region, using an LSTM in the inference module. Results show mean squared error on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.
Results on analogy with L1 regularization for various values of λ, for translation and scaling, using an LSTM in the inference module. Results show mean accuracy on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.
Results on arithmetic with L1 regularization for various values of λ, using an LSTM in the inference module. Results show mean accuracy on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.
Results on analogy for one-step DPP-A over the complete grid codes for various values of λ, for translation and scaling, using an LSTM in the inference module. Results show mean accuracy on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.
Results on analogy for one-step DPP-A within frequencies for various values of λ, for translation and scaling, using an LSTM in the inference module. Results show mean accuracy on each region, averaged over 3 trained networks, with error bars showing the standard error of the mean.