# Implementing Bluesky Recommendation Models Using Graph Neural Networks

## Background and Motivation

Following interesting accounts is key to enjoying social media such as Twitter, Facebook, and Bluesky. This is especially true for Bluesky: it is an early-stage product, and such accounts are not easy to find. In this project, I implemented a recommendation model for Bluesky. You can try the model at https://bluesky-recs.raven-sun.ts.net. This note gives a technical overview of it.

---

![User recommendations for my account](https://hackmd.io/_uploads/H1bSjkEcT.png)

**Figure 1**. User recommendations for my account. The listed users have traits similar to mine (Japanese, software engineer or researcher).

![Post recommendations for my account](https://hackmd.io/_uploads/BkHSjJE96.png)

**Figure 2**. Post recommendations for my account. The listed posts are mainly from users with similar traits.

---

## Link-Based Recommendation

There are two types of recommendation algorithms. One is *content-based* recommendation, which exploits information about the contents (texts, images, and videos) of the entities. The other is *link-based* recommendation, which only uses the relationships (follows, likes, etc.) between the entities.

The **content-based approach** is the standard way to build recommendation models. As of 2024, many pre-trained content models, such as language models and image recognition models, are freely available, so we can combine them to build a recommendation model. However, this approach has one drawback in the social media context. Social media usually involves many people from different spaces (countries, genres, etc.), whereas a content model typically understands content well only within a single space. Therefore, this approach tends to yield a recommendation model that works well only in a particular domain.
This is particularly problematic for me, the first user of my model: my native language is Japanese, and language models perform far worse on Japanese than on English.

The **link-based approach** can mitigate this problem. It does not try to understand the contents; instead, it exploits the relationships between the entities. Because relationships are (usually) universal across spaces, the model works well in all of them. The drawback is that relationships are a weaker signal than contents, so link-based models need more training data and a more advanced training strategy.

In this project, I use a **graph neural network (GNN)** for link-based recommendation, as I'm a fan of GNN technology. A GNN takes a graph as input and produces node embeddings as output. By feeding it the Bluesky social network, which consists of the follow/followed relations between users, we obtain user embeddings. Intuitively, if two users have similar embeddings, they have similar traits (language, social status, interests, etc.), owing to the nature of the social network. We can use the embeddings for various tasks, including user and post recommendation. Below, I explain how I implemented link-based recommendation with a GNN.

# User Embedding

The first task is to compute user embeddings using a GNN.

## Data Collection

We first collect the Bluesky social network. I implemented a crawler in Python with the [atproto package](https://pypi.org/project/atproto/) and executed it on my home Kubernetes cluster, with the runs managed by [Prefect](https://www.prefect.io/). The collected data is stored as a table in my home [MariaDB](https://mariadb.org/) server with the schema `follows: (source_id: STRING, target_id: STRING)`, where each row means that `source_id` follows `target_id`. Another option is to store the data in a graph database such as [neo4j](https://neo4j.com/); however, in my experience, storing data in an RDB is often more flexible.
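To make the schema concrete, here is a minimal sketch that uses Python's built-in `sqlite3` as a stand-in for MariaDB. The `follows` table mirrors the schema above. The `follows_with_degrees` view is a hypothetical definition of the degree-annotated table that the sampling query later in this note reads from (its actual definition is not shown); in particular, pairing `out_degree` with the source and `in_degree` with the target is an assumption. The toy rows are invented for illustration.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    -- Each row: source_id follows target_id.
    CREATE TABLE follows (source_id TEXT NOT NULL, target_id TEXT NOT NULL);

    -- Hypothetical view: annotate each edge with the out-degree of its
    -- source and the in-degree of its target.
    CREATE VIEW follows_with_degrees AS
    SELECT f.source_id, f.target_id, o.out_degree, i.in_degree
    FROM follows AS f
    JOIN (SELECT source_id, COUNT(*) AS out_degree
          FROM follows GROUP BY source_id) AS o
      ON f.source_id = o.source_id
    JOIN (SELECT target_id, COUNT(*) AS in_degree
          FROM follows GROUP BY target_id) AS i
      ON f.target_id = i.target_id;
    """
)

# Toy data: alice follows bob and carol; bob follows carol.
conn.executemany(
    "INSERT INTO follows VALUES (?, ?)",
    [("alice", "bob"), ("alice", "carol"), ("bob", "carol")],
)

rows = conn.execute(
    "SELECT source_id, target_id, out_degree, in_degree "
    "FROM follows_with_degrees ORDER BY source_id, target_id"
).fetchall()
for row in rows:
    print(row)
```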
Currently, there are 1.5M users and 35M follow relations.

## Model

Now we implement a GNN model to compute the node embeddings. As we are implementing a link-based model, we don't have any content-related features. Instead, we assign *ID features* to the nodes. The ID feature is a hash of the node ID, an integer in $\{0, \dots, 2^{32} - 1\}$. Using ID features lets the model distinguish the nodes, at the cost of losing generalisability to unseen nodes (i.e., the model becomes *transductive*). In our application, we only make predictions on the Bluesky graph, so this is not a problem. The hashing itself is applied during dataset materialisation, described below.

To process the ID features in a GNN, we embed the IDs in a $d$-dimensional latent space. PyTorch provides `torch.nn.Embedding` for this purpose; however, it is not applicable here because the input domain is too large: a table with $2^{32}$ rows would not fit in memory. Instead, we use the **[QR embedding](https://gowrishankar.info/blog/quotient-remainder-embedding-dealing-with-categorical-data-in-a-dlrm-a-paper-review/)**. The QR embedding takes a modulus $M$, here $2^{16}$, computes the embeddings of the quotient `xs // M` and the remainder `xs % M` with two `torch.nn.Embedding` tables, and outputs their element-wise product. Since every ID in $\{0, \dots, 2^{32} - 1\}$ has a unique (quotient, remainder) pair, the two tables need only $2 \times 2^{16} = 131{,}072$ rows in total. The code is as follows.

```python
import torch


class QREmbedding(torch.nn.Module):
    def __init__(self, output_dim: int):
        super().__init__()
        self.modulo = 2**16
        # One table for the quotient, one for the remainder.
        self.q_embedding = torch.nn.Embedding(self.modulo, output_dim)
        self.r_embedding = torch.nn.Embedding(self.modulo, output_dim)

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        # xs: int64 tensor of hashed IDs in [0, 2**32).
        q_embedding = self.q_embedding(xs // self.modulo)
        r_embedding = self.r_embedding(xs % self.modulo)
        # The element-wise product combines the two partial embeddings.
        return q_embedding * r_embedding
```

Now we implement the GNN part. Although there are several GNN frameworks, such as [PyG](https://pytorch-geometric.readthedocs.io/en/latest/) and [DGL](https://www.dgl.ai/), we implement it in pure PyTorch because the model is simple enough.
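Message passing in pure PyTorch boils down to a scatter-sum: `Tensor.index_add_(dim, index, source)` adds each row of `source` into the row of `self` selected by `index`. The toy sketch below (three users and their vectors are invented for illustration) aggregates neighbour vectors for a `2 x E` edge tensor whose row 0 receives and row 1 sends, the same convention the `Gnn` module uses, and concatenates the results into a $3d$-dimensional input per node. Dividing each aggregate by the node's degree (e.g., obtained with `torch.bincount`) would turn the sum into a mean; GIN uses the sum.

```python
import torch

# Three users with 2-dimensional latent vectors.
xs = torch.tensor([[1.0, 0.0],
                   [0.0, 1.0],
                   [1.0, 1.0]])

# Follow edges as a 2 x E tensor: row 0 = follower, row 1 = followee.
# User 0 follows users 1 and 2; user 1 follows user 2.
follow = torch.tensor([[0, 0, 1],
                       [1, 2, 2]])

# Sum the followees' vectors into each follower's row.
agg_follow = torch.zeros_like(xs).index_add_(0, follow[0], xs[follow[1]])

# Swapping the rows gives the "followed" relation: sum the followers' vectors.
followed = follow[[1, 0], :]
agg_followed = torch.zeros_like(xs).index_add_(0, followed[0], xs[followed[1]])

# Own vector plus the two aggregates: 3 * 2 = 6 features per node.
h = torch.cat([xs, agg_follow, agg_followed], dim=1)
print(h)
```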
Recall that our graph has two relations, "follow" and its reverse, "followed". For each node, we sum the latent vectors of its neighbours over each of these two relations. Concatenating these two aggregates with the node's own latent vector gives a $3d$-dimensional vector per node, which we then feed into a two-layer MLP to obtain the GNN embedding. This model is essentially a heterogeneous version of the [Graph Isomorphism Network (GIN)](https://arxiv.org/abs/1810.00826), which also uses sum aggregation.

```python
from typing import Dict

import torch


class Gnn(torch.nn.Module):
    def __init__(self, latent_dim: int):
        super().__init__()
        self.nn = torch.nn.Sequential(
            torch.nn.Linear(3 * latent_dim, latent_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(latent_dim, latent_dim),
        )

    def forward(
        self, xs: torch.Tensor, edge_index: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        zs = [xs]
        for es in edge_index.values():
            # es is a 2 x E tensor; add the vectors of nodes es[1] into rows es[0].
            zs.append(torch.zeros_like(xs).index_add_(0, es[0], xs[es[1]]))
        # [self, follow aggregate, followed aggregate] -> 3 * latent_dim features.
        return self.nn(torch.cat(zs, dim=1))
```

The resulting model is as follows.

```python
class BlueskyUserModel(torch.nn.Module):
    def __init__(self, latent_dim: int):
        super().__init__()
        self.embedding = QREmbedding(latent_dim)
        self.gnn = Gnn(latent_dim)
        self.accuracy = 0.0

    def forward(
        self,
        xs: torch.Tensor,
        edge_index: Dict[str, torch.Tensor],
    ) -> torch.Tensor:
        xs = self.embedding(xs)  # hashed IDs -> latent vectors
        xs = self.gnn(xs, edge_index)  # one round of message passing
        return xs
```

We train the above model on a **link-prediction** task, i.e., we minimise the distance between the embeddings of two nodes if one follows the other. We explain this part later, because the actual implementation depends on the dataset materialisation described below.

## Dataset Materialisation

We convert the collected data into a PyTorch-readable format. This process is called *dataset materialisation*, and I believe it is the most important part of a GNN implementation. The trivial approach materialises the whole graph by dumping the table as a PyTorch tensor.
However, this does not work well in our case because of **memory consumption**: a GNN consumes memory proportional to the size of the input graph, so the whole Bluesky network does not fit in the memory of a standard computer.

To mitigate this issue, we perform *neighbour sampling*. We observe that, in our GNN model, the embedding of a node can be computed from its $1$-hop neighbourhood alone. Hence, instead of materialising the whole graph, we sample the edges to be predicted, take the $1$-hop neighbourhoods of their endpoints, and materialise only that subgraph. In this approach, we can control the memory consumption through the edge sample rate.

This sampling procedure is easily implemented as a single SQL query. In addition to neighbour sampling, we apply *degree pruning*: an edge is kept only if `r_edge * in_degree` and `r_edge * out_degree` are both at most `max_degree`, so that, on average, at most about `max_degree` edges per node survive. The query is a Python format string: `{name}`, `{edge_sample_rate}`, and `{max_degree}` are filled in before execution, and `follows_with_degrees` is the `follows` table annotated with `out_degree` and `in_degree` columns.

```sql
WITH follows_with_rands AS (
    SELECT
        source_id,
        target_id,
        out_degree,
        in_degree,
        -- Pseudo-random numbers in [0, 1), deterministic per (name, edge).
        CRC32(CONCAT('target_{name}', source_id, target_id)) / POW(2, 32) AS r_target,
        CRC32(CONCAT('edge_{name}', source_id, target_id)) / POW(2, 32) AS r_edge
    FROM follows_with_degrees
),
follows_processed AS (
    -- Degree pruning, plus marking the edges to be predicted.
    SELECT
        source_id,
        target_id,
        out_degree,
        in_degree,
        (r_target < {edge_sample_rate}) AS is_target
    FROM follows_with_rands
    WHERE r_edge * in_degree <= {max_degree}
      AND r_edge * out_degree <= {max_degree}
),
seed_nodes AS (
    -- Endpoints of the edges to be predicted.
    SELECT source_id AS node_id FROM follows_processed WHERE is_target
    UNION DISTINCT
    SELECT target_id AS node_id FROM follows_processed WHERE is_target
)
-- All pruned edges incident to a seed node, i.e., their 1-hop neighbourhoods.
SELECT source_id, target_id, is_target
FROM follows_processed
WHERE source_id IN (SELECT * FROM seed_nodes)
   OR target_id IN (SELECT * FROM seed_nodes)
```

Now we discuss the in-memory format. It consists of the graph structure (`xs` and `edge_index`), the edges to be predicted (`label_index`), and the mapping between node IDs and node indices (`node_ids`).
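```python
from dataclasses import dataclass
from typing import Dict

import numpy as np
import pandas as pd
import torch


@dataclass
class Graph:
    xs: torch.Tensor  # hashed node IDs, one per node
    edge_index: Dict[str, torch.Tensor]  # relation name -> 2 x E index tensor
    label_index: torch.Tensor  # 2 x L index tensor of the edges to be predicted
    node_ids: np.ndarray  # node index -> original node ID

    @classmethod
    def from_df(cls, edge_df: pd.DataFrame) -> "Graph":
        # Assign a contiguous index to every node appearing in the subgraph.
        source_df = edge_df[["source_id", "is_target"]].rename(
            columns={"source_id": "node_id"}
        )
        target_df = edge_df[["target_id", "is_target"]].rename(
            columns={"target_id": "node_id"}
        )
        node_df = (
            pd.concat((source_df, target_df)).groupby("node_id").max().reset_index()
        )
        node_df["index"] = node_df.index
        node_ids = node_df["node_id"].values
        # `hashing` maps a node ID to an integer in [0, 2**32);
        # its definition is not shown in this note.
        xs = torch.from_numpy(np.vectorize(hashing)(node_ids))

        # Map both endpoints to node indices (index_x = source, index_y = target).
        # Merging only these columns keeps the edge-level `is_target` from being
        # shadowed by the node-level `is_target` of node_df.
        index_df = node_df[["node_id", "index"]]
        edge_index_df = edge_df.merge(
            index_df, left_on="source_id", right_on="node_id"
        ).merge(
            index_df, left_on="target_id", right_on="node_id"
        )
        label_index = torch.from_numpy(
            edge_index_df[edge_index_df["is_target"] == 1][
                ["index_x", "index_y"]
            ].values.T
        )
        # All edges are used for message passing, including the label edges;
        # filtering with `is_target == 0` here would exclude them.
        follow = torch.from_numpy(edge_index_df[["index_x", "index_y"]].values.T)
        edge_index = {"follow": follow, "followed": follow[[1, 0], :]}
        return cls(
            xs=xs,
            edge_index=edge_index,
            label_index=label_index,
            node_ids=node_ids,
        )
```

We store the materialised data in a locally hosted S3-compatible bucket (MinIO).

```python
import cloudpickle
import fsspec
import pandas as pd

# `sql` is the query template above, `engine` is the database connection
# (e.g., a SQLAlchemy engine), and `s3_options` holds the MinIO endpoint and
# credentials. `edge_sample_rate` and `max_degree` are hyperparameters whose
# values are not given in this note.
fs = fsspec.filesystem("s3", **s3_options)
for i in range(128):
    name = f"{i:03}"
    # A different `name` salts the CRC32 hashes, so each shard is a new sample.
    query = sql.format(
        name=name, edge_sample_rate=edge_sample_rate, max_degree=max_degree
    )
    edge_df = pd.read_sql_query(query, engine)
    graph = Graph.from_df(edge_df)
    path = f"bluesky/dataset/{name}.pkl"
    with fs.open(path, "wb") as fp:
        fp.write(cloudpickle.dumps(graph))
```

## Training

Now we implement the training procedure, using [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/). Our goal is to bring the embeddings of two nodes closer if they are connected by an edge. For this purpose, we train our model using [`TripletMarginLoss`](https://pytorch.org/docs/stable/generated/torch.nn.TripletMarginLoss.html).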
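As a reference point, here is a small check (random tensors, purely illustrative) that `TripletMarginLoss` with margin $1.0$ and its default Euclidean distance equals the hinge $\max(d(a, p) - d(a, n) + 1, 0)$ averaged over the batch, where $d$ is `torch.nn.functional.pairwise_distance`. The loss for a triplet is zero once the negative is farther from the anchor than the positive by at least the margin.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
anchor = torch.randn(8, 4)
positive = torch.randn(8, 4)
negative = torch.randn(8, 4)

margin = 1.0
# Defaults: p=2 (Euclidean distance) and mean reduction over the batch.
loss_fn = torch.nn.TripletMarginLoss(margin=margin)

# Manual version: hinge on the gap between positive and negative distances.
d_ap = F.pairwise_distance(anchor, positive)
d_an = F.pairwise_distance(anchor, negative)
manual = torch.clamp(d_ap - d_an + margin, min=0.0).mean()

print(loss_fn(anchor, positive, negative).item(), manual.item())
```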
This loss takes three vectors, `anchor`, `positive`, and `negative`, and encourages `positive` to be closer to `anchor` than `negative` is, by at least a margin. We first compute the embeddings with the model above. Then we use the sampled edges as positive examples, setting the source side as the `anchor` and the target side as the `positive`. To obtain the `negative`, we shuffle the target side within the batch; see the code below. This approach is called **[in-batch negative sampling](https://arxiv.org/abs/1511.06939)**.

```python
import pytorch_lightning as pl
import torch


class BlueskyEmbeddingTask(pl.LightningModule):
    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model
        self.loss_fn = torch.nn.TripletMarginLoss(margin=1.0)

    def training_step(self, batch):
        # `trainer.fit` receives a list of dataloaders, so each batch is a list
        # with one entry per loader; batch[0] is the Graph.
        graph = batch[0]
        xs = self.model(graph.xs, graph.edge_index)
        label_index = graph.label_index
        # In-batch negatives: pair each anchor with a randomly shuffled target.
        pi = torch.randperm(len(label_index[1]))
        anchor = xs[label_index[0]]
        positive = xs[label_index[1]]
        negative = xs[label_index[1][pi]]
        loss = self.loss_fn(anchor, positive, negative)
        # Fraction of triplets in which the positive is closer than the negative.
        accuracy = (
            (
                (anchor - positive).pow(2).sum(dim=1)
                < (anchor - negative).pow(2).sum(dim=1)
            )
            .float()
            .mean()
        )
        self.accuracy = accuracy
        self.log("loss", loss, prog_bar=True)
        self.log("accuracy", accuracy, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
```

We construct the data loader with [TorchData](https://pytorch.org/data/beta/index.html), which can read from any storage supported by fsspec. Hence, the data loader is simply implemented as follows.

```python
from torchdata.dataloader2 import DataLoader2
from torchdata.datapipes.iter import IterableWrapper

dp = (
    IterableWrapper(["s3://bluesky/dataset/"])
    .list_files_by_fsspec(**s3_options)
    .open_files_by_fsspec("rb", **s3_options)
    # `decoder` turns each (path, stream) pair into a Graph, e.g. by
    # unpickling it; its definition is not shown in this note.
    .map(decoder)
)
dataloader = DataLoader2(dp)
```

The whole training procedure is as follows.
```python
import pytorch_lightning as pl

# latent_dim and max_epochs are hyperparameters whose values are not given here.
model = BlueskyUserModel(latent_dim)
task = BlueskyEmbeddingTask(model)
trainer = pl.Trainer(max_epochs=max_epochs)
trainer.fit(task, train_dataloaders=[dataloader])
```

## Inference

After training, we compute the user embeddings. There are two options. One is to predict every node in each sampled subgraph and average the predictions over subgraphs. The other is to predict only the nodes whose whole (pruned) neighbourhood is contained in the subgraph. The former has larger coverage, while the latter yields more accurate embeddings. I tested both and decided to take the first option.

```python
import cloudpickle
import fsspec
import numpy as np
import torch
from tqdm import tqdm

# First pass: collect every node ID that appears in any subgraph.
node_ids = set()
for graph in tqdm(dataloader):
    for node_id in graph.node_ids:
        node_ids.add(node_id)
node_ids = np.array(sorted(node_ids))
size = len(node_ids)

with torch.no_grad():
    _, dim = model(graph.xs, graph.edge_index).shape
embeddings = np.zeros((size, dim))
counts = np.zeros((size, 1))

# Second pass: sum each node's embedding over all subgraphs it appears in.
for graph in tqdm(dataloader):
    batch_node_ids = np.searchsorted(node_ids, graph.node_ids)
    with torch.no_grad():
        xs = model(graph.xs, graph.edge_index).numpy()
    counts[batch_node_ids] += 1
    embeddings[batch_node_ids] += xs

# Average over appearances and upload the result.
artifact = {
    "node_ids": node_ids,
    "embeddings": embeddings / counts,
}
fs = fsspec.filesystem("s3", **s3_options)
path = "bluesky/artifact/embeddings.pkl"
with fs.open(path, "wb") as fp:
    fp.write(cloudpickle.dumps(artifact))
print(f"Wrote data to {path}")
```

# User Recommendation

Now that we have user embeddings, we can make recommendations by simply computing the $k$-nearest neighbours ($k$-NN) of a user's embedding. We use [ScaNN](https://github.com/google-research/google-research/tree/master/scann) as an embeddable $k$-NN database.

# Post Recommendation

Post recommendation is a bit trickier. A natural idea is to build a user-post bipartite graph and compute both user and post embeddings with a GNN. However, this approach does not work, for the following reason. We are usually interested in recent posts (from the last few hours), so a post's embedding must be available within a short time of its creation.
The bipartite graph approach cannot satisfy this latency requirement, because inference on the larger graph has to run as a batch job.

Instead, we define the embedding of a post as the average of the embeddings of the users who liked it. We subscribe to the [Firehose API](https://atproto.blue/en/latest/atproto_firehose/index.html) to retrieve all likes on Bluesky, and accumulate the user embeddings on the posts. To bound the memory usage, we use an [LRU cache](https://en.wikipedia.org/wiki/Cache_replacement_policies#LRU) to keep only the active posts. For faster retrieval, we periodically rebuild a ScaNN $k$-NN database in the background.

In this implementation, we lose the accumulated information whenever the program restarts, and in my observation it takes about one hour to provide meaningful post recommendations again. Hence, we introduce a cache on the stream, using [Fluvio](https://www.fluvio.io/) for this purpose.
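A minimal sketch of the post-side bookkeeping described above, under stated assumptions: user embeddings are NumPy vectors already looked up by the caller, the class `PostEmbeddings` and its capacity are hypothetical, and a brute-force Euclidean search (Euclidean, to match the training loss) stands in for the ScaNN index. The Firehose subscription and the Fluvio cache are omitted. Each like adds the liker's embedding to a running sum for the post and moves the post to the most-recently-used end of an `OrderedDict`; when the capacity is exceeded, the least recently liked post is evicted.

```python
from collections import OrderedDict

import numpy as np


class PostEmbeddings:
    """Hypothetical store: mean liker embedding per post, bounded by LRU."""

    def __init__(self, dim: int, capacity: int):
        self.dim = dim
        self.capacity = capacity
        # post_id -> (sum of liker embeddings, number of likes), oldest first.
        self.posts = OrderedDict()

    def add_like(self, post_id: str, user_embedding: np.ndarray) -> None:
        total, count = self.posts.pop(post_id, (np.zeros(self.dim), 0))
        # Re-inserting places the post at the most-recently-used end.
        self.posts[post_id] = (total + user_embedding, count + 1)
        if len(self.posts) > self.capacity:
            self.posts.popitem(last=False)  # evict the least recently liked post

    def embedding(self, post_id: str) -> np.ndarray:
        total, count = self.posts[post_id]
        return total / count

    def recommend(self, user_embedding: np.ndarray, k: int) -> list:
        """Brute-force k-NN by Euclidean distance (stand-in for ScaNN)."""
        if not self.posts:
            return []
        post_ids = list(self.posts)
        matrix = np.stack([self.embedding(p) for p in post_ids])
        distances = np.linalg.norm(matrix - user_embedding, axis=1)
        return [post_ids[i] for i in np.argsort(distances)[:k]]


# Usage with invented vectors: capacity 2, so the third post evicts the oldest.
store = PostEmbeddings(dim=2, capacity=2)
store.add_like("post-a", np.array([1.0, 0.0]))
store.add_like("post-a", np.array([0.0, 1.0]))
store.add_like("post-b", np.array([5.0, 5.0]))
store.add_like("post-c", np.array([0.0, 0.0]))  # evicts post-a
print(list(store.posts))
print(store.recommend(np.array([4.0, 4.0]), k=1))
```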