# Todos

## Problem Formulation

Tabular data prediction regards a row of the table $T_{i,:}$ as a data sample and a column of the table $T_{:,j}$ as a feature field. An entry $T_{i,j}$ represents the feature value of the $j$-th field for the $i$-th data sample.

We focus on tabular data prediction on a single table, since multiple tables can be joined into a single table using various techniques, e.g., Deep Feature Synthesis.

## 1. Column-wise: AIM: Arbitrary Order Interaction Machines

### Methodology

**Interaction Component**

A beneficial feature interaction set: $I_i=\{x_{i1}, x_{i2},\dots,x_{ik}\}.$

We could use **one layer message passing** to obtain the product feature interactions.

$$
\forall x_{ij} \in I_i:\quad s_{ij} = \mathbf{1}_{x_{ij}>0},\quad h_{ij} = |x_{ij}|,\\
m_i = \mathrm{AGG}_{x_{ij}\in I_i}\{\log(h_{ij})\},\\
s_i = \Big(\deg(m_i)-\sum_{j} s_{ij}\Big) \bmod 2,\\
h_i = \exp(m_i)\odot (1-2s_i).
$$

Here $s_i\in\{0,1\}$ is the parity of the number of negative entries in $I_i$, so the factor $1-2s_i$ restores the sign of the product.

For each input, we generate $N$ feature interaction sets $\mathbf{I}=\{I_1, \dots, I_N\}$. The above process is highly parallelizable; it can be seen as one layer of message passing on a directed graph. With linear aggregators, the computational complexity is $O(Ed)$, where $E=\sum_{I\in \mathbf{I}}|I|$ is the number of edges of the constructed graph. With attention-based aggregators, the complexity is $O(Ed^2)$.

Another important problem is how to generate $\mathbf{I}$, which we solve in the pruning component.

**Pruning Component**

Principle: Sufficiency and Low Redundancy.

* Feature-value-wise
    1. Regularize via the deep infomax loss and the contrastive loss.
       Note: Deep infomax mainly contributes a new loss. It reduces the computation of the mutual information between two variables to computing the KL divergence between the joint distribution and the product of the marginal distributions. In practice, it uses a discriminator to tell the two distributions apart.
       ![](https://i.imgur.com/6jQ0SaF.png)
    2. Other pruning methods like attention-based, etc.
* Field-wise (Lighter)
    1. A simple way could be to first precompute the MI between each pair of fields and then use a greedy algorithm to decide the feature combination sets.

I will survey and find out more pruning methods.

**Overall**

DeepFM can be seen as a reduced version of AIM in which $\mathbf{I}$ is restricted to pair-wise feature interactions and single features. Therefore, theoretically, we could beat DeepFM given a good pruning component. Also, we generate arbitrary-order product feature interactions, which are strictly $n$-way instead of $2$-way.

**Baselines to compare with**: DeepFM, xDeepFM, AutoFIS, PNN, AFM, etc.

**Other**: We can further combine AIM with the multi-rows trick introduced below.

## 2. Row-wise

Motivations:

1. Why we need multi-rows
    - Previous methods: Make predictions based on the target input features and model parameters.
      $$p(y^*\mid x,\theta).$$
    - A single-row model might not be able to remember the relations between datapoints well.
        - For example, $\theta$ is trained over the whole data space, while the feature distribution tends to be long-tailed, so it could be biased for rarely hit features.
    - In recommendation scenarios, modeling user behavior sequences has proven to be effective.
    - Our method: Consider direct dependencies between datapoints.
      $$p(y^*\mid x,D_{train},\theta).$$

    Related Work: NPT, Hopular, SAINT, RIM, PET, etc.
2. Why we need retrieval.
    - Homogeneity is important.
    - Save resources.

**Base model**: $f(\cdot)$.

**Index Factory**: $R(\cdot)$.

**Single row prediction process**: $T_{i,:} \rightarrow f(\cdot)\rightarrow h_i \rightarrow p(\cdot)\rightarrow \hat{y}_i$.

**Multi-rows usage**:

$$
N(t) = R(T_{t,:}),\\
h_t = f(T_{t,:}),\\
a_{ti} = \frac{\exp(h_t^\top h_i)}{\sum_{j \in N(t)}\exp(h_t^\top h_j)},\\
n_t = \sum_{i \in N(t)} a_{ti}\cdot \Phi(y_i),\\
\hat{y}_t = p(h_t, n_t).
$$

**Base model**: $y_{base} = \Phi_\theta(X_t).$

**Knowledge patterns**: $\mathbf{K}=\{K_i\}.$

$\forall K_i\in\mathbf{K}$:

$$
N_i(t) = K_i(X_t,\theta),\\
\alpha_{tj} = \frac{\exp(\mathbf{h}_t^\top\mathbf{h}_j)}{\sum_{j' \in N_i(t)}\exp(\mathbf{h}_t^\top\mathbf{h}_{j'})},\\
\mathbf{n}^i_t = \sum_{j \in N_i(t)} \alpha_{tj}\cdot \mathbf{h}_j,\\
\mathbf{kn}^i_t = \left[ \sum_{j \in N_i(t)} \alpha_{tj}\cdot \mathbf{h}_j \,\Big\|\, \sum_{j \in N_i(t)} \alpha_{tj}\cdot \mathbf{y}_j\right].
$$

Then,

$$
\{\alpha_{ti}\} \leftarrow \mathrm{Attn}(\{\mathbf{n}^i_t\}, \mathbf{h}_t),\\
\mathbf{kn}_t = \sum_i\alpha_{ti}\cdot\mathbf{kn}^i_t,\\
\hat{y}_t = \Phi_\theta(X_t) + \mathrm{MLP}(\mathbf{kn}_t).
$$

Optimal knowledge patterns:

$$\max\; I((\mathbf{k}_1, \mathbf{k}_2,\dots, \mathbf{k}_N);y)-\alpha \sum_{i=1}^N H(\mathbf{k}_i\mid y)-\beta \sum_{i\neq j}I(\mathbf{k}_i;\mathbf{k}_j;y).$$

Its approximation:

$$\max\; \sum_{i=1}^N I(\mathbf{k}_i;y)-\sigma\sum_{i\neq j}I(\mathbf{k}_i;\mathbf{k}_j).$$

$$
\max \sum_i I(\mathbf{k}_i;y)=\max \frac1N\sum_i \left(\log D_\theta(\mathbf{k}_i, c^+)+\log\left(1- D_\theta(\mathbf{k}_i, c^-)\right)\right),\\
\min \sum_{i\neq j}I(\mathbf{k}_i;\mathbf{k}_j) = \max \frac1N\sum_i \left( \log D_w\left(\mathbf{k}_i',\mathbf{k}_i''\right)+\log\left(1- D_w(\mathbf{k}_i',\mathbf{k}_j'')\right)\right).
$$

Note: $n_t$ should serve as a residual, so that even if we retrieve garbage, it does not degrade the performance.

**Benefits brought from GNNs (PET)**: (Just a discussion. It doesn't mean that we need to use it. We could borrow the most useful part and simplify it into a trick.)

- Stabilize multi-rows distributions. GNNs have good stabilization and generalization ability.
    - FATE offered a proof for this with non-linear GCN.
- Since the assumptions of GNNs and label propagation (LP) are the same, we can use features and labels in parallel. Also, LP is stronger than simple attention-based aggregation (though more costly). Even on graphs with weaker homophily, LP is strong.

## Deployment Feasibility

1. Use SimHash in place of attention as a soft filter:

    **References**: https://arxiv.org/pdf/2108.04468.pdf https://arxiv.org/pdf/2205.10249.pdf

    **ETA**: ![](https://i.imgur.com/9s4rCYw.png)

    **SDIM**: ![](https://i.imgur.com/NO9iblA.png)

    - Retrieve by the Hamming distance between SimHash fingerprints. Computing a Hamming distance is an atomic $O(1)$ operation, so the retrieval time is linear.
    - **The SimHash/bucketing step is independent of the candidate item, so it can be run concurrently during the recall/pre-ranking stage and cached in advance. At prediction time the result can be fetched directly, and the whole pipeline can be online and end-to-end.**

Why ETA & SDIM work:

- Feasible for modeling long user behavior sequences.
- Unlike two-stage methods (e.g., SIM), they alleviate the divergent-targets and distribution-shift problems.

## Other Possible Improvements & Paper Entry Points

1. OOD (out-of-distribution) generalization.
    - In the CV field there is a 2022 paper that uses KNN for OOD (mainly using the KNN distance for detection): https://arxiv.org/pdf/2204.06507.pdf
    - Online retrieval uses the latest embeddings, which can alleviate distribution shift.
    - We can also design the learning methodology to address the OOD problem.
        - Survey: https://arxiv.org/pdf/2108.13624.pdf
        - Disentangled representation learning, causality, etc.
2. Introduce labels, and analyze from the energy-based-model perspective the benefits that KNN and labels bring to the OOD problem.
3. On the efficiency side, we can explore hash functions that have fewer collisions and better mimic target attention.
    - Feature Hashing for Large Scale Multitask Learning: https://alex.smola.org/papers/2009/Weinbergeretal09.pdf

## 3. Retrieval

### Sparse Retrieval

Retrieve relevant rows in explicit space according to feature co-occurrences, e.g., TF-IDF values.

**Strength**: Light and accurate for tables with fewer fields.

**Weakness**: When faced with tables with a large number of fields, sparse retrieval performs poorly, since: 1) It treats the distance between any two feature values of a feature field as equal. 2) It is vulnerable to garbage columns.

- **An easy fix**: Only use the top-K fields that have the highest mutual information with the labels for retrieval, e.g., use user attribute columns and item id columns.

Some evidence:

![](https://i.imgur.com/foUAUmc.png)

### Dense Retrieval

Retrieve relevant rows in latent space according to cosine distances between feature vectors.
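
A minimal sketch of this retrieval step, assuming the row embeddings are already available as NumPy arrays. The function name `cosine_top_k` and the toy vectors are illustrative only and do not come from these notes. It performs exact brute-force search; a deployed system would more likely use an approximate KNN index.

```python
import numpy as np


def cosine_top_k(query, pool, k):
    """Return (indices, similarities) of the k pool rows closest to query.

    query: 1-D array of shape (d,), the embedding of the target row.
    pool:  2-D array of shape (n, d), embeddings of the candidate rows.
    """
    # Normalise both sides so that a dot product equals cosine similarity.
    q = query / (np.linalg.norm(query) + 1e-12)
    p = pool / (np.linalg.norm(pool, axis=1, keepdims=True) + 1e-12)
    sims = p @ q
    k = min(k, len(sims))
    # Sort by descending similarity and keep the first k rows.
    top = np.argsort(-sims, kind="stable")[:k]
    return top, sims[top]


# Toy embeddings, assumed to be produced by some encoder.
pool = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
query = np.array([2.0, 0.1])
idx, sims = cosine_top_k(query, pool, k=2)
```

Because the vectors are normalised first, the scale of the query does not affect the ranking, only its direction does.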

**Strength**: Given expressive representations, it could

* be aware of the "real" distance between different feature values,
* apply different term weights to different fields,
* be more robust to garbage columns.

**Weakness**: Relies on expressive representations.

* I tried pure dense retrieval on dota2games, phillipines, etc. (tables that have hundreds of columns), and it retrieved nearly garbage.
* Ray continued with sparse + dense retrieval on the Taoqi datasets, and it got good results.

**Industrial applications**: Approximate KNNs. Some open tools:

- Python binding: Faiss, N2.
- C++ based: OpenSearch.

### Other thoughts

Retrieval itself relies on good cross-row dependencies.

Retrieval helps a lot on sequential tasks.

Temporal information could be important in retrieval.

* Retrieve from histories, since we can only see the past.
* Retrieve from the whole data pool. This gives a larger retrieval space, but may suffer from the inductive bias, since in the real world we can only see the past.

## Action Items

1. Prepare for the interview.
2. Try AIM. I have constructed the pruning process, so it should be fast to try.
3. Try the multi-rows tricks.
4. Need more surveys.

Some chitchat: My VPN has been collapsing recently. Very inconvenient :( Any recommendations?