Twitter, Bluesky, Facebook などのオンライン・ソーシャルネットワークを活用するには自分の興味にあったアカウントをフォローすることが大事です.そのために重要な役割を果たすのが「おすすめユーザ推薦 (friend recommendation)」です. ところで最近わたしは趣味で Bluesky のおすすめユーザ推薦モデルを実装して遊んでいました.実装したモデルは https://bluesky-app.raven-sun.ts.net/userrec から遊べます.このモデルはグラフ情報のみを使った(つまりポスト情報などは一切使わない)ものです.ポスト情報なしでもこれくらいの精度が出るのだなあというデモだと思ってください.Bluesky は API が公開されているのでこういう遊びがやりやすいのがよいですね. この記事では上記のモデルの実装手続を紹介します.手続きはデータ準備パート(ソーシャルグラフのクロール)と,機械学習パートからなります.以下それぞれを説明していきます. # データ準備 クローラを実装して Bluesky のソーシャルグラフを取得します.取得したソーシャルグラフは MariaDB 上にテーブル (source_id: VARCHAR, target_id: VARCHAR) として保存します.neo4j などのグラフデータベース上にソーシャルグラフを再現する手もあるのですが,わたしの経験上,RDB にグラフを入れてしまうほうが何かと融通が効いて簡単です. ソーシャルグラフの取得には幅優先探索を用います.時間がかかる処理なので Prefect を用いて定期実行するワークフローとして実装します.MariaDB 上に (user_id, last_update) をカラムとしてもつテーブルを作り,定期的に updated_at が遠いユーザを N 人とってきてそのユーザの follows を取得していきます.Bluesky の API 制限に引っかからないように適宜 sleep しながらクロールします.たまにフォロー数が 100K とかあるユーザがいて困りますが,待てば終わります. 以下がこの部分のコード例です. ```python import asyncio import datetime import random import time from typing import Any, Dict, List import atproto from prefect import flow, get_run_logger, task from prefect.tasks import exponential_backoff from sqlalchemy import Column, create_engine, DateTime, func, String from sqlalchemy.dialects.mysql import insert from sqlalchemy.orm import declarative_base, sessionmaker Base = declarative_base() engine_str = f"mysql+pymysql://root:password@mariadb.example.com/bluesky" class User(Base): __tablename__ = "users" user_id = Column(String(32), nullable=False, primary_key=True) created_at = Column(DateTime(timezone=False), server_default=func.now()) updated_at = Column(DateTime(timezone=False), server_default=func.now()) class Follow(Base): __tablename__ = "follows" source_id = Column(String(32), nullable=False, primary_key=True) target_id = Column(String(32), nullable=False, primary_key=True) created_at = Column(DateTime(timezone=False), server_default=func.now()) updated_at = Column(DateTime(timezone=False), server_default=func.now()) def init_table(): engine = create_engine(engine_str) SessionClass = sessionmaker(engine) session = SessionClass() Base.metadata.create_all(engine) session.add( User(user_id="did:plc:dqxsa5cjfrzulhalom4kuyd2") ) # tmaehara.bsky.social session.commit() @task( tags=["mariadb-session"], retries=3, retry_delay_seconds=exponential_backoff(backoff_factor=60), ) def insert_table_impl( Target: Base, values_list: List[Dict], updated_at: Any = None, ): statement = insert(Target).values(values_list) if updated_at is None: statement = statement.prefix_with("ignore") else: statement = statement.on_duplicate_key_update( updated_at=updated_at or datetime.datetime.now(), ) engine = create_engine(engine_str) SessionClass = sessionmaker(engine) session = SessionClass() session.execute(statement) session.commit() def insert_table( source_id: str, target_ids: List[str], ): follows_values_list = [] target_users_values_list = [] for target_id in target_ids: target_users_values_list.append( { "user_id": target_id, } ) follows_values_list.append( { "source_id": source_id, "target_id": target_id, } ) insert_table_impl(User, target_users_values_list) insert_table_impl( Follow, follows_values_list, updated_at=datetime.datetime.now(), ) @flow def bluesky_social_network(size: int = 32): logger = get_run_logger() engine = create_engine(engine_str) SessionClass = sessionmaker(engine) session = SessionClass() users = session.query(User).order_by(User.updated_at).limit(size).all() user_ids = [user.user_id for user in users] logger.info(user_ids) logger.info(f"minimum updated_at = {users[0].updated_at}") # user_id with updated_at > today is on-process insert_table_impl( User, [{"user_id": user_id} for user_id in user_ids], updated_at=datetime.datetime.now() + datetime.timedelta(days=3650), ) client = atproto.Client() usernames = "username" password = "password" client.login(username, password) for user_id in user_ids: cursor = None num_follows = 0 follows = [] while True: time.sleep(1) # Retrieve until success sleep_time = 10 new_follows = [] for _ in range(12): try: response = client.get_follows( actor=user_id, cursor=cursor, limit=100, ) new_follows = response.follows new_cursor = response.cursor break except Exception as e: logger.error(e) if e.response is not None: # actor not found --- ignore if e.response.status_code == 400: break # rate limit exceed if e.response.status_code == 429: reset_at = e.response.headers["ratelimit-reset"] + 1 reset_at_date = time.strftime( "%Y-%m-%d %H:%M:%S", time.localtime(reset_at) ) duration = reset_at - datetime.now().timestamp() logger.error( f"Sleep until {reset_at_date} ({duration} sec)" ) time.sleep(duration) continue else: logger.error("Unknown error. Re-login") client.login(username, password) time.sleep(sleep_time) sleep_time *= 2 if len(new_follows) == 0: break num_follows += len(new_follows) follows += [profile_view.did for profile_view in new_follows] cursor = new_cursor logger.info(f"Retrieved {num_follows} follows from {user_id}") if len(follows) >= 1000: insert_table(user_id, follows) follows = [] if len(follows) >= 1: insert_table(user_id, follows) insert_table_impl( User, [{"user_id": user_id}], updated_at=datetime.datetime.now(), ) ``` # 学習 集めたデータから「フォローしそうなユーザ」を探すモデルを学習する部分について説明します.様々なやり方がありますが,ここではわたしの好みでグラフニューラルネットワーク (GNN) を使います.GNN はグラフを食べてすべての頂点上の潜在ベクトルを返すモデルなので,フォローしそうなユーザが近い潜在ベクトルをもつように学習します. ちなみに,この定義はちょっと筋が悪いです.というのは,(L2) 距離を使うと A と B が近いとき B と A も近くなるので相互フォローに強い prior がかかります.言い換えると,公式アカウントみたいな「フォローはしないがフォローはされるアカウント」をうまく処理できません.これに対応するには follow 用と followed 用でベクトルを分ければよいのですが,実験してみたところ特に性能が変わらなかったので簡単のためにベクトルを分けずにやっています. ## データ実体化 MariaDB 上に保存されているグラフを PyTorch で簡単に読み込める形に変換します.このような手続を実体化 (materialisation) と呼びます.Bluesky をはじめとするソーシャルグラフは次数に強い偏りをもつので何も考えずに全部読み込むとモデル全体が次数の高い頂点(特に bot)にモデル性能が引きずられてしまいます.これを防ぐため,次数が期待的に閾値以下になるようにサンプリングした部分グラフを利用します.そのようなサンプルを複数個取ることでできるだけ多くの枝をカバーし,サンプルへの過学習を防ぎます.GNN に詳しい人向けに言うと,この手続きは GraphSage のサンプリングをデータセット生成時点で行っていることに相当します. 上述の手続きは SQL で簡単に実装できます.各頂点の次数は何度も使うのであらかじめ計算しておき,サンプリング部分は乱数シードを変えて繰り返し実行します.手元のデータは高々 10M 枝程度なので特別な処理はいりません. 頂点の特徴量を特に集めていないので,ここでは頂点 ID をハッシュにした 32 bitの整数値を特徴量として使います.ID を特徴量にすると未知頂点に対する性能が劣化しがちですが(transductive 的になる),今の場合 Bluesky のグラフでだけうまくいけばよいので特に問題ありません.枝は follow と逆向き枝である follolwed の両方を区別して考慮します.これらをパックしたものを Graph クラスとしてまとめ,pickle して S3 互換ストレージに保存しておきます. ```python import argparse import hashlib from dataclasses import dataclass from typing import Dict import cloudpickle import fsspec import numpy as np import pandas as pd import torch from sqlalchemy import create_engine, text s3_options = { "client_kwargs": {"endpoint_url": "http://minio.example.com:9000"}, "key": "username", "secret": "password", } engine_str = "mysql+pymysql://root:password@mariadb.example.com/bluesky" def hashing(s: str) -> int: md5 = hashlib.md5(s.encode()) return int(md5.hexdigest(), 16) % 2**32 @dataclass class Graph: xs: torch.Tensor edge_index: Dict[str, torch.Tensor] node_ids: np.array @classmethod def from_df(cls, edge_df): node_df = pd.DataFrame( { "node_id": pd.concat( [edge_df["source_id"], edge_df["target_id"]] ).unique(), } ).reset_index() node_ids = node_df["node_id"].values xs = torch.from_numpy(np.vectorize(hashing)(node_ids)) follow = torch.from_numpy( ( edge_df.merge(node_df, left_on="source_id", right_on="node_id").merge( node_df, left_on="target_id", right_on="node_id" ) )[["index_x", "index_y"]].values.T ) edge_index = {"follow": follow, "followed": follow[[1, 0], :]} return cls(xs, edge_index, node_ids) def compute_follows_with_degrees(): engine = create_engine(engine_str) with engine.connect() as conn: conn.execute(text("DROP TABLE IF EXISTS follows_with_degrees")) conn.execute(text(""" CREATE TABLE follows_with_degrees AS ( WITH out_degrees AS ( SELECT source_id, COUNT(1) AS out_degree FROM follows GROUP BY source_id ), in_degrees AS ( SELECT target_id, COUNT(1) AS in_degree FROM follows GROUP BY target_id ) SELECT e.source_id, e.target_id, o.out_degree, i.in_degree FROM bluesky.follows e LEFT JOIN out_degrees o ON e.source_id = o.source_id LEFT JOIN in_degrees i ON e.target_id = i.target_id )""")) def generate_samples( num_samples: int, max_degree: int, ): for i in range(num_samples): engine = create_engine(engine_str) sql = f""" WITH follows_with_rands AS ( SELECT source_id, target_id, out_degree, in_degree, CRC32(CONCAT('{i}', source_id, target_id)) / POW(2, 32) AS r FROM follows_with_degrees ) SELECT source_id, target_id FROM follows_with_rands WHERE r * in_degree <= {max_degree} AND r * out_degree <= {max_degree} """ edge_df = pd.read_sql_query(sql, engine) graph = Graph.from_df(edge_df) fs = fsspec.filesystem("s3", **s3_options) path = f"bluesky/dataset/{i:03}.pkl" with fs.open(path, "wb", **s3_options) as fp: p = cloudpickle.dumps(graph) fp.write(p) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--num_samples", type=int, default=32) parser.add_argument("--max_degree", type=int, default=32) parser.add_argument("--preprocess", type=bool, default=True) args = parser.parse_args() if args.preprocess: compute_follows_with_degrees() generate_samples(args.num_samples, args.max_degree) ``` ## 機械学習 機械学習パートです.PyTorch を使って GNN モデルを実装します.世の中には torch-geometric などの GNN ライブラリもありますが,正直フルスクラッチしてもたいして手間ではないので,わたしはだいたいいつも(業務でも論文でも)フルスクラッチしています. データセット生成の部分で各頂点に $0, .., 2^{32} - 1$ の整数値を特徴量として与えました.GNN を適用するためはこれらを低次元潜在空間(加算が意味をもつ空間)に埋め込む必要があります.PyTorch で整数を潜在空間に埋め込むには `torch.nn.Embedding` を使うのが普通ですが,今の場合値域が大きすぎるので別の手段が必要となります.ここでは `QREmbedding` と呼ばれる,適当な法に対する商 (Quotient) と剰余 (Remainder) を埋め込むトリックを用います.法は値域の平方根にとるのが最も空間効率的です. GNN 部分は次のように実装します.入力グラフがもつ follow, followed という2種類の枝それぞれについて繋がっている先の埋め込みの平均をとります.これによって各頂点は自分自身・follow平均・followed平均という3つのベクトルをもつことになります.これらをスタックした 3d 次元ベクトルを多層パーセプトロンに食わせたものが GNN 埋め込みです.近傍平均は PyTorch の index_add を使うと簡単に実装できます.なお,このモデルは Heterogeneous GIN と呼ばれているものと同値です. 定義した GNN モデルを自己リンク予測によって学習します.タスクは PyTorch Lightning の LightningModule として実装するのが簡単です.与えられたグラフの (source_id, target_id) を正例,target_id をシャッフルしたものを負例として TripletMarginLoss で学習します.これにより source_id が target_id をフォローしているときに source_id の埋め込みと target_id の埋め込みが近くなるようになります. 学習するために S3 互換ストレージから実体化したデータを読み出すことになります.データ読み込みは torchdata の DataPipe/DataLoader2 を使うのが現代的です.S3 互換ストレージからデータを読めそうな DataPipe がいくつか提供されていますが,ここでは FSSpec 系を使うことにします. 学習で得られた埋め込みを S3 互換ストレージに保存して学習プロセスは終了です. ```python import argparse from typing import Dict import cloudpickle import fsspec import lightning as pl import torch from lightning.pytorch.loggers import MLFlowLogger from torchdata.dataloader2 import DataLoader2 from torchdata.datapipes.iter import IterableWrapper s3_options = { "client_kwargs": {"endpoint_url": "http://minio.example.com:9000"}, "key": "username", "secret": "password", } engine_str = "mysql+pymysql://root:password@mariadb.example.com/bluesky" class QREmbedding(torch.nn.Module): def __init__(self, output_dim: int): super().__init__() self.q_embedding = torch.nn.Embedding(2**16, output_dim) self.r_embedding = torch.nn.Embedding(2**16, output_dim) def forward(self, xs: torch.Tensor): q_embedding = self.q_embedding(xs // 2**16) r_embedding = self.r_embedding(xs % 2**16) return q_embedding * r_embedding class Gnn(torch.nn.Module): def __init__(self, hidden_dim): super().__init__() self.nn = torch.nn.Sequential( torch.nn.Linear(3 * hidden_dim, hidden_dim), torch.nn.ReLU(), torch.nn.Linear(hidden_dim, hidden_dim), ) def forward( self, xs: torch.Tensor, edge_index: Dict[str, torch.Tensor] ) -> torch.Tensor: zs = [xs] for edge_type, es in edge_index.items(): zs.append(torch.zeros_like(xs).index_add_(0, es[0], xs[es[1]])) return self.nn(torch.concat(zs, axis=1)) class BlueskyUserModel(torch.nn.Module): def __init__(self, dim): super().__init__() self.embedding = QREmbedding(dim) self.gnn = Gnn(dim) self.accuracy = 0.0 def forward( self, xs: torch.Tensor, edge_index: Dict[str, torch.Tensor], ): xs = self.embedding(xs) xs = self.gnn(xs, edge_index) return xs class LinkPredictionTask(pl.LightningModule): def __init__(self, dim: int): super().__init__() self.model = BlueskyUserModel(dim) self.loss_fn = torch.nn.TripletMarginLoss(1.0) def training_step(self, batch): graph = batch[0] xs = self.model(graph.xs, graph.edge_index) es = graph.edge_index["follow"] pi = torch.randperm(len(es[1])) anchor = xs[es[0]] positive = xs[es[1]] negative = xs[es[1][pi]] loss = self.loss_fn(anchor, positive, negative) accuracy = ( ( (anchor - positive).pow(2).sum(axis=1) < (anchor - negative).pow(2).sum(axis=1) ) .float() .mean() ) self.accuracy = accuracy self.log("loss", loss, prog_bar=True) self.log("accuracy", accuracy, prog_bar=True) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters()) def decoder(input): return cloudpickle.loads(input[1].read()) def get_dataloader(): dp = ( IterableWrapper(["s3://bluesky/dataset/"]) .list_files_by_fsspec(**s3_options) .open_files_by_fsspec("rb", **s3_options) .map(decoder) ) return DataLoader2(dp) def train( latent_dim: int, max_epochs: int, ): mlf_logger = MLFlowLogger( experiment_name="bluesky-link-prediction", tracking_uri="https://mlflow.example.com/", ) dataloader = get_dataloader() task = LinkPredictionTask(latent_dim) trainer = pl.Trainer(max_epochs=max_epochs, logger=mlf_logger) trainer.fit(task, train_dataloaders=[dataloader]) trained_model = task.model for graph in dataloader: break embeddings = trained_model(graph.xs, graph.edge_index).detach().numpy() artifact = { "node_ids": graph.node_ids, "embeddings": embeddings, } fs = fsspec.filesystem("s3", **s3_options) path = "bluesky/artifact/embeddings.pkl" with fs.open(path, "wb", **s3_options) as fp: fp.write(cloudpickle.dumps(artifact)) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--latent_dim", type=int, default=128) parser.add_argument("--max_epochs", type=int, default=5) args = parser.parse_args() train( latent_dim=args.latent_dim, max_epochs=args.max_epochs, ) ``` ## API / UI FastAPI と uvicorn を使って API サーバを作ります.API サーバは学習で得られた埋め込みを読み込み,ベクトル検索エンジンである ScaNN に入れることで高速検索できるようにしておきます. API サーバと Web サーバを分けるのが面倒だったので FastAPI から JavaScript を直書きした HTML ファイルを表示しています.デザインセンスとフロントエンドの知識がどちらもないのでこんなもんです. これらを Kubernetes で動かし,tailscale の Funnel を使って公開しています. ```python import pickle from typing import List import atproto import numpy as np import pendulum import scann from fastapi import FastAPI, HTTPException, Request from fastapi.responses import FileResponse, HTMLResponse from pydantic import BaseModel app = FastAPI() import fsspec s3_options = { "client_kwargs": {"endpoint_url": "http://minio.example.com:9000"}, "key": "username", "secret": "password", } class RecommendationService: def __init__(self): self.node_ids = None self.embeddings = None self.last_retrieved = pendulum.now().add(years=-1) self.update_model() def update_model(self): current = pendulum.now() if current < self.last_retrieved.add(days=1): return False fs = fsspec.filesystem("s3", **s3_options) path = f"bluesky/artifact/embeddings.pkl" with fs.open(path, "rb", **s3_options) as fp: artifact = pickle.loads(fp.read()) self.node_ids = artifact["node_ids"] self.embeddings = artifact["embeddings"] self.searcher = ( scann.scann_ops_pybind.builder(self.embeddings, 20, "squared_l2") .score_ah(2, anisotropic_quantization_threshold=0.2) .reorder(40) .build() ) self.last_retrieved = current return True def execute(self, node_id: str) -> List[str]: self.update_model() indices = np.where(self.node_ids == node_id)[0] if len(indices) == 0: return [], [] query_index = indices[0] query_vector = self.embeddings[query_index] candidates, distances = self.searcher.search(query_vector) return self.node_ids[candidates].tolist(), distances.tolist() service = RecommendationService() @app.get("/{path}", response_class=HTMLResponse) @app.get("/", response_class=HTMLResponse) def read_html(request: Request, path: str = None): if not path or path in ("index.html", "index.htm"): return FileResponse("index.html") else: raise HTTPException(status_code=404) class ApiRequest(BaseModel): handle: str @app.post("/api") def api(request: ApiRequest): resolver = atproto.IdResolver() query_handle = request.handle.strip() query = resolver.handle.resolve(query_handle) candidates, scores = service.execute(query) handles = [resolver.did.resolve(did).get_handle() for did in candidates] return [ { "handle": handle, "score": score, } for handle, score in zip(handles, scores) ] ```