Twitter, Bluesky, Facebook などのオンライン・ソーシャルネットワークを活用するには自分の興味にあったアカウントをフォローすることが大事です.そのために重要な役割を果たすのが「おすすめユーザ推薦 (friend recommendation)」です.
ところで最近わたしは趣味で Bluesky のおすすめユーザ推薦モデルを実装して遊んでいました.実装したモデルは https://bluesky-app.raven-sun.ts.net/userrec から遊べます.このモデルはグラフ情報のみを使った(つまりポスト情報などは一切使わない)ものです.ポスト情報なしでもこれくらいの精度が出るのだなあというデモだと思ってください.Bluesky は API が公開されているのでこういう遊びがやりやすいのがよいですね.
この記事では上記のモデルの実装手続を紹介します.手続きはデータ準備パート(ソーシャルグラフのクロール)と,機械学習パートからなります.以下それぞれを説明していきます.
# データ準備
クローラを実装して Bluesky のソーシャルグラフを取得します.取得したソーシャルグラフは MariaDB 上にテーブル (source_id: VARCHAR, target_id: VARCHAR) として保存します.neo4j などのグラフデータベース上にソーシャルグラフを再現する手もあるのですが,わたしの経験上,RDB にグラフを入れてしまうほうが何かと融通が効いて簡単です.
ソーシャルグラフの取得には幅優先探索を用います.時間がかかる処理なので Prefect を用いて定期実行するワークフローとして実装します.MariaDB 上に (user_id, last_update) をカラムとしてもつテーブルを作り,定期的に updated_at が遠いユーザを N 人とってきてそのユーザの follows を取得していきます.Bluesky の API 制限に引っかからないように適宜 sleep しながらクロールします.たまにフォロー数が 100K とかあるユーザがいて困りますが,待てば終わります.
以下がこの部分のコード例です.
```python
import asyncio
import datetime
import random
import time
from typing import Any, Dict, List
import atproto
from prefect import flow, get_run_logger, task
from prefect.tasks import exponential_backoff
from sqlalchemy import Column, create_engine, DateTime, func, String
from sqlalchemy.dialects.mysql import insert
from sqlalchemy.orm import declarative_base, sessionmaker
Base = declarative_base()
engine_str = f"mysql+pymysql://root:password@mariadb.example.com/bluesky"
class User(Base):
__tablename__ = "users"
user_id = Column(String(32), nullable=False, primary_key=True)
created_at = Column(DateTime(timezone=False), server_default=func.now())
updated_at = Column(DateTime(timezone=False), server_default=func.now())
class Follow(Base):
__tablename__ = "follows"
source_id = Column(String(32), nullable=False, primary_key=True)
target_id = Column(String(32), nullable=False, primary_key=True)
created_at = Column(DateTime(timezone=False), server_default=func.now())
updated_at = Column(DateTime(timezone=False), server_default=func.now())
def init_table():
engine = create_engine(engine_str)
SessionClass = sessionmaker(engine)
session = SessionClass()
Base.metadata.create_all(engine)
session.add(
User(user_id="did:plc:dqxsa5cjfrzulhalom4kuyd2")
) # tmaehara.bsky.social
session.commit()
@task(
tags=["mariadb-session"],
retries=3,
retry_delay_seconds=exponential_backoff(backoff_factor=60),
)
def insert_table_impl(
Target: Base,
values_list: List[Dict],
updated_at: Any = None,
):
statement = insert(Target).values(values_list)
if updated_at is None:
statement = statement.prefix_with("ignore")
else:
statement = statement.on_duplicate_key_update(
updated_at=updated_at or datetime.datetime.now(),
)
engine = create_engine(engine_str)
SessionClass = sessionmaker(engine)
session = SessionClass()
session.execute(statement)
session.commit()
def insert_table(
source_id: str,
target_ids: List[str],
):
follows_values_list = []
target_users_values_list = []
for target_id in target_ids:
target_users_values_list.append(
{
"user_id": target_id,
}
)
follows_values_list.append(
{
"source_id": source_id,
"target_id": target_id,
}
)
insert_table_impl(User, target_users_values_list)
insert_table_impl(
Follow,
follows_values_list,
updated_at=datetime.datetime.now(),
)
@flow
def bluesky_social_network(size: int = 32):
logger = get_run_logger()
engine = create_engine(engine_str)
SessionClass = sessionmaker(engine)
session = SessionClass()
users = session.query(User).order_by(User.updated_at).limit(size).all()
user_ids = [user.user_id for user in users]
logger.info(user_ids)
logger.info(f"minimum updated_at = {users[0].updated_at}")
# user_id with updated_at > today is on-process
insert_table_impl(
User,
[{"user_id": user_id} for user_id in user_ids],
updated_at=datetime.datetime.now() + datetime.timedelta(days=3650),
)
client = atproto.Client()
usernames = "username"
password = "password"
client.login(username, password)
for user_id in user_ids:
cursor = None
num_follows = 0
follows = []
while True:
time.sleep(1)
# Retrieve until success
sleep_time = 10
new_follows = []
for _ in range(12):
try:
response = client.get_follows(
actor=user_id,
cursor=cursor,
limit=100,
)
new_follows = response.follows
new_cursor = response.cursor
break
except Exception as e:
logger.error(e)
if e.response is not None:
# actor not found --- ignore
if e.response.status_code == 400:
break
# rate limit exceed
if e.response.status_code == 429:
reset_at = e.response.headers["ratelimit-reset"] + 1
reset_at_date = time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime(reset_at)
)
duration = reset_at - datetime.now().timestamp()
logger.error(
f"Sleep until {reset_at_date} ({duration} sec)"
)
time.sleep(duration)
continue
else:
logger.error("Unknown error. Re-login")
client.login(username, password)
time.sleep(sleep_time)
sleep_time *= 2
if len(new_follows) == 0:
break
num_follows += len(new_follows)
follows += [profile_view.did for profile_view in new_follows]
cursor = new_cursor
logger.info(f"Retrieved {num_follows} follows from {user_id}")
if len(follows) >= 1000:
insert_table(user_id, follows)
follows = []
if len(follows) >= 1:
insert_table(user_id, follows)
insert_table_impl(
User,
[{"user_id": user_id}],
updated_at=datetime.datetime.now(),
)
```
# 学習
集めたデータから「フォローしそうなユーザ」を探すモデルを学習する部分について説明します.様々なやり方がありますが,ここではわたしの好みでグラフニューラルネットワーク (GNN) を使います.GNN はグラフを食べてすべての頂点上の潜在ベクトルを返すモデルなので,フォローしそうなユーザが近い潜在ベクトルをもつように学習します.
ちなみに,この定義はちょっと筋が悪いです.というのは,(L2) 距離を使うと A と B が近いとき B と A も近くなるので相互フォローに強い prior がかかります.言い換えると,公式アカウントみたいな「フォローはしないがフォローはされるアカウント」をうまく処理できません.これに対応するには follow 用と followed 用でベクトルを分ければよいのですが,実験してみたところ特に性能が変わらなかったので簡単のためにベクトルを分けずにやっています.
## データ実体化
MariaDB 上に保存されているグラフを PyTorch で簡単に読み込める形に変換します.このような手続を実体化 (materialisation) と呼びます.Bluesky をはじめとするソーシャルグラフは次数に強い偏りをもつので何も考えずに全部読み込むとモデル全体が次数の高い頂点(特に bot)にモデル性能が引きずられてしまいます.これを防ぐため,次数が期待的に閾値以下になるようにサンプリングした部分グラフを利用します.そのようなサンプルを複数個取ることでできるだけ多くの枝をカバーし,サンプルへの過学習を防ぎます.GNN に詳しい人向けに言うと,この手続きは GraphSage のサンプリングをデータセット生成時点で行っていることに相当します.
上述の手続きは SQL で簡単に実装できます.各頂点の次数は何度も使うのであらかじめ計算しておき,サンプリング部分は乱数シードを変えて繰り返し実行します.手元のデータは高々 10M 枝程度なので特別な処理はいりません.
頂点の特徴量を特に集めていないので,ここでは頂点 ID をハッシュにした 32 bitの整数値を特徴量として使います.ID を特徴量にすると未知頂点に対する性能が劣化しがちですが(transductive 的になる),今の場合 Bluesky のグラフでだけうまくいけばよいので特に問題ありません.枝は follow と逆向き枝である follolwed の両方を区別して考慮します.これらをパックしたものを Graph クラスとしてまとめ,pickle して S3 互換ストレージに保存しておきます.
```python
import argparse
import hashlib
from dataclasses import dataclass
from typing import Dict
import cloudpickle
import fsspec
import numpy as np
import pandas as pd
import torch
from sqlalchemy import create_engine, text
s3_options = {
"client_kwargs": {"endpoint_url": "http://minio.example.com:9000"},
"key": "username",
"secret": "password",
}
engine_str = "mysql+pymysql://root:password@mariadb.example.com/bluesky"
def hashing(s: str) -> int:
md5 = hashlib.md5(s.encode())
return int(md5.hexdigest(), 16) % 2**32
@dataclass
class Graph:
xs: torch.Tensor
edge_index: Dict[str, torch.Tensor]
node_ids: np.array
@classmethod
def from_df(cls, edge_df):
node_df = pd.DataFrame(
{
"node_id": pd.concat(
[edge_df["source_id"], edge_df["target_id"]]
).unique(),
}
).reset_index()
node_ids = node_df["node_id"].values
xs = torch.from_numpy(np.vectorize(hashing)(node_ids))
follow = torch.from_numpy(
(
edge_df.merge(node_df, left_on="source_id", right_on="node_id").merge(
node_df, left_on="target_id", right_on="node_id"
)
)[["index_x", "index_y"]].values.T
)
edge_index = {"follow": follow, "followed": follow[[1, 0], :]}
return cls(xs, edge_index, node_ids)
def compute_follows_with_degrees():
engine = create_engine(engine_str)
with engine.connect() as conn:
conn.execute(text("DROP TABLE IF EXISTS follows_with_degrees"))
conn.execute(text("""
CREATE TABLE follows_with_degrees AS (
WITH out_degrees AS (
SELECT
source_id,
COUNT(1) AS out_degree
FROM follows
GROUP BY source_id
), in_degrees AS (
SELECT
target_id,
COUNT(1) AS in_degree
FROM follows
GROUP BY target_id
)
SELECT
e.source_id,
e.target_id,
o.out_degree,
i.in_degree
FROM bluesky.follows e
LEFT JOIN out_degrees o
ON e.source_id = o.source_id
LEFT JOIN in_degrees i
ON e.target_id = i.target_id
)"""))
def generate_samples(
num_samples: int,
max_degree: int,
):
for i in range(num_samples):
engine = create_engine(engine_str)
sql = f"""
WITH follows_with_rands AS (
SELECT
source_id,
target_id,
out_degree,
in_degree,
CRC32(CONCAT('{i}', source_id, target_id)) / POW(2, 32) AS r
FROM follows_with_degrees
)
SELECT
source_id,
target_id
FROM follows_with_rands
WHERE
r * in_degree <= {max_degree}
AND r * out_degree <= {max_degree}
"""
edge_df = pd.read_sql_query(sql, engine)
graph = Graph.from_df(edge_df)
fs = fsspec.filesystem("s3", **s3_options)
path = f"bluesky/dataset/{i:03}.pkl"
with fs.open(path, "wb", **s3_options) as fp:
p = cloudpickle.dumps(graph)
fp.write(p)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num_samples", type=int, default=32)
parser.add_argument("--max_degree", type=int, default=32)
parser.add_argument("--preprocess", type=bool, default=True)
args = parser.parse_args()
if args.preprocess:
compute_follows_with_degrees()
generate_samples(args.num_samples, args.max_degree)
```
## 機械学習
機械学習パートです.PyTorch を使って GNN モデルを実装します.世の中には torch-geometric などの GNN ライブラリもありますが,正直フルスクラッチしてもたいして手間ではないので,わたしはだいたいいつも(業務でも論文でも)フルスクラッチしています.
データセット生成の部分で各頂点に $0, .., 2^{32} - 1$ の整数値を特徴量として与えました.GNN を適用するためはこれらを低次元潜在空間(加算が意味をもつ空間)に埋め込む必要があります.PyTorch で整数を潜在空間に埋め込むには `torch.nn.Embedding` を使うのが普通ですが,今の場合値域が大きすぎるので別の手段が必要となります.ここでは `QREmbedding` と呼ばれる,適当な法に対する商 (Quotient) と剰余 (Remainder) を埋め込むトリックを用います.法は値域の平方根にとるのが最も空間効率的です.
GNN 部分は次のように実装します.入力グラフがもつ follow, followed という2種類の枝それぞれについて繋がっている先の埋め込みの平均をとります.これによって各頂点は自分自身・follow平均・followed平均という3つのベクトルをもつことになります.これらをスタックした 3d 次元ベクトルを多層パーセプトロンに食わせたものが GNN 埋め込みです.近傍平均は PyTorch の index_add を使うと簡単に実装できます.なお,このモデルは Heterogeneous GIN と呼ばれているものと同値です.
定義した GNN モデルを自己リンク予測によって学習します.タスクは PyTorch Lightning の LightningModule として実装するのが簡単です.与えられたグラフの (source_id, target_id) を正例,target_id をシャッフルしたものを負例として TripletMarginLoss で学習します.これにより source_id が target_id をフォローしているときに source_id の埋め込みと target_id の埋め込みが近くなるようになります.
学習するために S3 互換ストレージから実体化したデータを読み出すことになります.データ読み込みは torchdata の DataPipe/DataLoader2 を使うのが現代的です.S3 互換ストレージからデータを読めそうな DataPipe がいくつか提供されていますが,ここでは FSSpec 系を使うことにします.
学習で得られた埋め込みを S3 互換ストレージに保存して学習プロセスは終了です.
```python
import argparse
from typing import Dict
import cloudpickle
import fsspec
import lightning as pl
import torch
from lightning.pytorch.loggers import MLFlowLogger
from torchdata.dataloader2 import DataLoader2
from torchdata.datapipes.iter import IterableWrapper
s3_options = {
"client_kwargs": {"endpoint_url": "http://minio.example.com:9000"},
"key": "username",
"secret": "password",
}
engine_str = "mysql+pymysql://root:password@mariadb.example.com/bluesky"
class QREmbedding(torch.nn.Module):
def __init__(self, output_dim: int):
super().__init__()
self.q_embedding = torch.nn.Embedding(2**16, output_dim)
self.r_embedding = torch.nn.Embedding(2**16, output_dim)
def forward(self, xs: torch.Tensor):
q_embedding = self.q_embedding(xs // 2**16)
r_embedding = self.r_embedding(xs % 2**16)
return q_embedding * r_embedding
class Gnn(torch.nn.Module):
def __init__(self, hidden_dim):
super().__init__()
self.nn = torch.nn.Sequential(
torch.nn.Linear(3 * hidden_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, hidden_dim),
)
def forward(
self, xs: torch.Tensor, edge_index: Dict[str, torch.Tensor]
) -> torch.Tensor:
zs = [xs]
for edge_type, es in edge_index.items():
zs.append(torch.zeros_like(xs).index_add_(0, es[0], xs[es[1]]))
return self.nn(torch.concat(zs, axis=1))
class BlueskyUserModel(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.embedding = QREmbedding(dim)
self.gnn = Gnn(dim)
self.accuracy = 0.0
def forward(
self,
xs: torch.Tensor,
edge_index: Dict[str, torch.Tensor],
):
xs = self.embedding(xs)
xs = self.gnn(xs, edge_index)
return xs
class LinkPredictionTask(pl.LightningModule):
def __init__(self, dim: int):
super().__init__()
self.model = BlueskyUserModel(dim)
self.loss_fn = torch.nn.TripletMarginLoss(1.0)
def training_step(self, batch):
graph = batch[0]
xs = self.model(graph.xs, graph.edge_index)
es = graph.edge_index["follow"]
pi = torch.randperm(len(es[1]))
anchor = xs[es[0]]
positive = xs[es[1]]
negative = xs[es[1][pi]]
loss = self.loss_fn(anchor, positive, negative)
accuracy = (
(
(anchor - positive).pow(2).sum(axis=1)
< (anchor - negative).pow(2).sum(axis=1)
)
.float()
.mean()
)
self.accuracy = accuracy
self.log("loss", loss, prog_bar=True)
self.log("accuracy", accuracy, prog_bar=True)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters())
def decoder(input):
return cloudpickle.loads(input[1].read())
def get_dataloader():
dp = (
IterableWrapper(["s3://bluesky/dataset/"])
.list_files_by_fsspec(**s3_options)
.open_files_by_fsspec("rb", **s3_options)
.map(decoder)
)
return DataLoader2(dp)
def train(
latent_dim: int,
max_epochs: int,
):
mlf_logger = MLFlowLogger(
experiment_name="bluesky-link-prediction",
tracking_uri="https://mlflow.example.com/",
)
dataloader = get_dataloader()
task = LinkPredictionTask(latent_dim)
trainer = pl.Trainer(max_epochs=max_epochs, logger=mlf_logger)
trainer.fit(task, train_dataloaders=[dataloader])
trained_model = task.model
for graph in dataloader:
break
embeddings = trained_model(graph.xs, graph.edge_index).detach().numpy()
artifact = {
"node_ids": graph.node_ids,
"embeddings": embeddings,
}
fs = fsspec.filesystem("s3", **s3_options)
path = "bluesky/artifact/embeddings.pkl"
with fs.open(path, "wb", **s3_options) as fp:
fp.write(cloudpickle.dumps(artifact))
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--latent_dim", type=int, default=128)
parser.add_argument("--max_epochs", type=int, default=5)
args = parser.parse_args()
train(
latent_dim=args.latent_dim,
max_epochs=args.max_epochs,
)
```
## API / UI
FastAPI と uvicorn を使って API サーバを作ります.API サーバは学習で得られた埋め込みを読み込み,ベクトル検索エンジンである ScaNN に入れることで高速検索できるようにしておきます.
API サーバと Web サーバを分けるのが面倒だったので FastAPI から JavaScript を直書きした HTML ファイルを表示しています.デザインセンスとフロントエンドの知識がどちらもないのでこんなもんです.
これらを Kubernetes で動かし,tailscale の Funnel を使って公開しています.
```python
import pickle
from typing import List
import atproto
import numpy as np
import pendulum
import scann
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import FileResponse, HTMLResponse
from pydantic import BaseModel
app = FastAPI()
import fsspec
s3_options = {
"client_kwargs": {"endpoint_url": "http://minio.example.com:9000"},
"key": "username",
"secret": "password",
}
class RecommendationService:
def __init__(self):
self.node_ids = None
self.embeddings = None
self.last_retrieved = pendulum.now().add(years=-1)
self.update_model()
def update_model(self):
current = pendulum.now()
if current < self.last_retrieved.add(days=1):
return False
fs = fsspec.filesystem("s3", **s3_options)
path = f"bluesky/artifact/embeddings.pkl"
with fs.open(path, "rb", **s3_options) as fp:
artifact = pickle.loads(fp.read())
self.node_ids = artifact["node_ids"]
self.embeddings = artifact["embeddings"]
self.searcher = (
scann.scann_ops_pybind.builder(self.embeddings, 20, "squared_l2")
.score_ah(2, anisotropic_quantization_threshold=0.2)
.reorder(40)
.build()
)
self.last_retrieved = current
return True
def execute(self, node_id: str) -> List[str]:
self.update_model()
indices = np.where(self.node_ids == node_id)[0]
if len(indices) == 0:
return [], []
query_index = indices[0]
query_vector = self.embeddings[query_index]
candidates, distances = self.searcher.search(query_vector)
return self.node_ids[candidates].tolist(), distances.tolist()
service = RecommendationService()
@app.get("/{path}", response_class=HTMLResponse)
@app.get("/", response_class=HTMLResponse)
def read_html(request: Request, path: str = None):
if not path or path in ("index.html", "index.htm"):
return FileResponse("index.html")
else:
raise HTTPException(status_code=404)
class ApiRequest(BaseModel):
handle: str
@app.post("/api")
def api(request: ApiRequest):
resolver = atproto.IdResolver()
query_handle = request.handle.strip()
query = resolver.handle.resolve(query_handle)
candidates, scores = service.execute(query)
handles = [resolver.did.resolve(did).get_handle() for did in candidates]
return [
{
"handle": handle,
"score": score,
}
for handle, score in zip(handles, scores)
]
```