diff --git a/api/atproto/cbor_gen.go b/api/atproto/cbor_gen.go index 65956036d..dec76f18e 100644 --- a/api/atproto/cbor_gen.go +++ b/api/atproto/cbor_gen.go @@ -935,6 +935,277 @@ func (t *SyncSubscribeRepos_Commit) UnmarshalCBOR(r io.Reader) (err error) { return nil } +func (t *SyncSubscribeRepos_Sync) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + fieldCount := 5 + + if t.Blocks == nil { + fieldCount-- + } + + if _, err := cw.Write(cbg.CborEncodeMajorType(cbg.MajMap, uint64(fieldCount))); err != nil { + return err + } + + // t.Did (string) (string) + if len("did") > 1000000 { + return xerrors.Errorf("Value in field \"did\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("did"))); err != nil { + return err + } + if _, err := cw.WriteString(string("did")); err != nil { + return err + } + + if len(t.Did) > 1000000 { + return xerrors.Errorf("Value in field t.Did was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Did))); err != nil { + return err + } + if _, err := cw.WriteString(string(t.Did)); err != nil { + return err + } + + // t.Rev (string) (string) + if len("rev") > 1000000 { + return xerrors.Errorf("Value in field \"rev\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("rev"))); err != nil { + return err + } + if _, err := cw.WriteString(string("rev")); err != nil { + return err + } + + if len(t.Rev) > 1000000 { + return xerrors.Errorf("Value in field t.Rev was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Rev))); err != nil { + return err + } + if _, err := cw.WriteString(string(t.Rev)); err != nil { + return err + } + + // t.Seq (int64) (int64) + if len("seq") > 1000000 { + return xerrors.Errorf("Value in field \"seq\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("seq"))); err != nil { + return err + } + if _, err := cw.WriteString(string("seq")); err != nil { + return err + } + + if t.Seq >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(t.Seq)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-t.Seq-1)); err != nil { + return err + } + } + + // t.Time (string) (string) + if len("time") > 1000000 { + return xerrors.Errorf("Value in field \"time\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("time"))); err != nil { + return err + } + if _, err := cw.WriteString(string("time")); err != nil { + return err + } + + if len(t.Time) > 1000000 { + return xerrors.Errorf("Value in field t.Time was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Time))); err != nil { + return err + } + if _, err := cw.WriteString(string(t.Time)); err != nil { + return err + } + + // t.Blocks (util.LexBytes) (slice) + if t.Blocks != nil { + + if len("blocks") > 1000000 { + return xerrors.Errorf("Value in field \"blocks\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("blocks"))); err != nil { + return err + } + if _, err := cw.WriteString(string("blocks")); err != nil { + return err + } + + if len(t.Blocks) > 2097152 { + return xerrors.Errorf("Byte array in field t.Blocks was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajByteString, uint64(len(t.Blocks))); err != nil { + return err + } + + if _, err := cw.Write(t.Blocks); err != 
nil { + return err + } + + } + return nil +} + +func (t *SyncSubscribeRepos_Sync) UnmarshalCBOR(r io.Reader) (err error) { + *t = SyncSubscribeRepos_Sync{} + + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("SyncSubscribeRepos_Sync: map struct too large (%d)", extra) + } + + n := extra + + nameBuf := make([]byte, 6) + for i := uint64(0); i < n; i++ { + nameLen, ok, err := cbg.ReadFullStringIntoBuf(cr, nameBuf, 1000000) + if err != nil { + return err + } + + if !ok { + // Field doesn't exist on this type, so ignore it + if err := cbg.ScanForLinks(cr, func(cid.Cid) {}); err != nil { + return err + } + continue + } + + switch string(nameBuf[:nameLen]) { + // t.Did (string) (string) + case "did": + + { + sval, err := cbg.ReadStringWithMax(cr, 1000000) + if err != nil { + return err + } + + t.Did = string(sval) + } + // t.Rev (string) (string) + case "rev": + + { + sval, err := cbg.ReadStringWithMax(cr, 1000000) + if err != nil { + return err + } + + t.Rev = string(sval) + } + // t.Seq (int64) (int64) + case "seq": + { + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + var extraI int64 + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.Seq = int64(extraI) + } + // t.Time (string) (string) + case "time": + + { + sval, err := cbg.ReadStringWithMax(cr, 1000000) + if err != nil { + return err + } + + t.Time = string(sval) + } + // t.Blocks (util.LexBytes) (slice) + case "blocks": + + maj, extra, err = cr.ReadHeader() + if err != nil { + return err + } + + if extra > 2097152 { + return fmt.Errorf("t.Blocks: byte array too large (%d)", extra) + } + if maj != cbg.MajByteString { + return fmt.Errorf("expected byte array") + } + + if extra > 0 { + t.Blocks = make([]uint8, extra) + } + + if _, err := io.ReadFull(cr, t.Blocks); err != nil { + return err + } + + default: + // Field doesn't exist on this type, so ignore it + if err := cbg.ScanForLinks(r, func(cid.Cid) {}); err != nil { + return err + } + } + } + + return nil +} func (t *SyncSubscribeRepos_Handle) MarshalCBOR(w io.Writer) error { if t == nil { _, err := w.Write(cbg.CborNull) diff --git a/atproto/identity/cache_directory.go b/atproto/identity/cache_directory.go index 0ecc6c292..751041ca4 100644 --- a/atproto/identity/cache_directory.go +++ b/atproto/identity/cache_directory.go @@ -11,6 +11,7 @@ import ( "github.com/hashicorp/golang-lru/v2/expirable" ) +// CacheDirectory is an implementation of identity.Directory with local cache of Handle and DID type CacheDirectory struct { Inner Directory ErrTTL time.Duration diff --git a/bgs/bgs.go b/bgs/bgs.go index 5f495b7fd..3bcacf1e6 100644 --- a/bgs/bgs.go +++ b/bgs/bgs.go @@ -178,6 +178,7 @@ func NewBGS(db *gorm.DB, ix *indexer.Indexer, repoman *repomgr.RepoManager, evtm slOpts.DefaultRepoLimit = config.DefaultRepoLimit slOpts.ConcurrencyPerPDS = config.ConcurrencyPerPDS slOpts.MaxQueuePerPDS = config.MaxQueuePerPDS + slOpts.Logger = bgs.log s, err := NewSlurper(db, bgs.handleFedEvent, slOpts) if err != nil { 
return nil, err diff --git a/bgs/fedmgr.go b/bgs/fedmgr.go index 710c15cb0..2f8ed2efc 100644 --- a/bgs/fedmgr.go +++ b/bgs/fedmgr.go @@ -56,6 +56,8 @@ type Slurper struct { shutdownResult chan []error ssl bool + + log *slog.Logger } type Limiters struct { @@ -73,6 +75,7 @@ type SlurperOptions struct { DefaultRepoLimit int64 ConcurrencyPerPDS int64 MaxQueuePerPDS int64 + Logger *slog.Logger } func DefaultSlurperOptions() *SlurperOptions { @@ -85,6 +88,7 @@ func DefaultSlurperOptions() *SlurperOptions { DefaultRepoLimit: 100, ConcurrencyPerPDS: 100, MaxQueuePerPDS: 1_000, + Logger: slog.Default(), } } @@ -115,6 +119,7 @@ func NewSlurper(db *gorm.DB, cb IndexCallback, opts *SlurperOptions) (*Slurper, ssl: opts.SSL, shutdownChan: make(chan bool), shutdownResult: make(chan []error), + log: opts.Logger, } if err := s.loadConfig(); err != nil { return nil, err diff --git a/cmd/bigsky/main.go b/cmd/bigsky/main.go index 08d638c3c..011d37539 100644 --- a/cmd/bigsky/main.go +++ b/cmd/bigsky/main.go @@ -310,7 +310,7 @@ func runBigsky(cctx *cli.Context) error { signals := make(chan os.Signal, 1) signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) - _, err := cliutil.SetupSlog(cliutil.LogOptions{}) + _, _, err := cliutil.SetupSlog(cliutil.LogOptions{}) if err != nil { return err } diff --git a/cmd/goat/firehose.go b/cmd/goat/firehose.go index eceda9948..f3590af73 100644 --- a/cmd/goat/firehose.go +++ b/cmd/goat/firehose.go @@ -79,7 +79,6 @@ func runFirehose(cctx *cli.Context) error { } relayHost := cctx.String("relay-host") - cursor := cctx.Int("cursor") dialer := websocket.DefaultDialer u, err := url.Parse(relayHost) @@ -87,10 +86,12 @@ func runFirehose(cctx *cli.Context) error { return fmt.Errorf("invalid relayHost URI: %w", err) } u.Path = "xrpc/com.atproto.sync.subscribeRepos" - if cursor != 0 { - u.RawQuery = fmt.Sprintf("cursor=%d", cursor) + if cctx.IsSet("cursor") { + u.RawQuery = fmt.Sprintf("cursor=%d", cctx.Int("cursor")) } - con, _, err := dialer.Dial(u.String(), http.Header{ + urlString := u.String() + slog.Debug("GET", "url", urlString) + con, _, err := dialer.Dial(urlString, http.Header{ "User-Agent": []string{fmt.Sprintf("goat/%s", versioninfo.Short())}, }) if err != nil { diff --git a/cmd/gosky/main.go b/cmd/gosky/main.go index 28b8602be..249b6b5f0 100644 --- a/cmd/gosky/main.go +++ b/cmd/gosky/main.go @@ -81,7 +81,7 @@ func run(args []string) { }, } - _, err := cliutil.SetupSlog(cliutil.LogOptions{}) + _, _, err := cliutil.SetupSlog(cliutil.LogOptions{}) if err != nil { fmt.Fprintf(os.Stderr, "logging setup error: %s\n", err.Error()) os.Exit(1) diff --git a/cmd/relay/Dockerfile b/cmd/relay/Dockerfile new file mode 100644 index 000000000..f7153638b --- /dev/null +++ b/cmd/relay/Dockerfile @@ -0,0 +1,49 @@ +# Run this dockerfile from the top level of the indigo git repository like: +# +# podman build -f ./cmd/relay/Dockerfile -t relay . + +### Compile stage +FROM golang:1.23-alpine3.20 AS build-env +RUN apk add --no-cache build-base make git + +ADD . 
/dockerbuild +WORKDIR /dockerbuild + +# timezone data for alpine builds +ENV GOEXPERIMENT=loopvar +RUN GIT_VERSION=$(git describe --tags --long --always) && \ + go build -tags timetzdata -o /relay ./cmd/relay + +### Build Frontend stage +FROM node:18-alpine as web-builder + +WORKDIR /app + +COPY ts/bgs-dash /app/ + +RUN yarn install --frozen-lockfile + +RUN yarn build + +### Run stage +FROM alpine:3.20 + +RUN apk add --no-cache --update dumb-init ca-certificates runit +ENTRYPOINT ["dumb-init", "--"] + +WORKDIR / +RUN mkdir -p data/relay +COPY --from=build-env /relay / +COPY --from=web-builder /app/dist/ public/ + +# small things to make golang binaries work well under alpine +ENV GODEBUG=netdns=go +ENV TZ=Etc/UTC + +EXPOSE 2470 + +CMD ["/relay"] + +LABEL org.opencontainers.image.source=https://github.com/bluesky-social/indigo +LABEL org.opencontainers.image.description="atproto Relay" +LABEL org.opencontainers.image.licenses=MIT diff --git a/cmd/relay/README.md b/cmd/relay/README.md new file mode 100644 index 000000000..c049e6880 --- /dev/null +++ b/cmd/relay/README.md @@ -0,0 +1,339 @@ + +atproto Relay Service +=============================== + +*NOTE: "Relays" used to be called "Big Graph Servers", or "BGS", or "bigsky". Many variables and packages still reference "bgs"* + +This is the implementation of an atproto Relay which is running in the production network, written and operated by Bluesky. + +In atproto, a Relay subscribes to multiple PDS hosts and outputs a combined "firehose" event stream. Downstream services can subscribe to this single firehose and get all relevant events for the entire network, or a specific sub-graph of the network. The Relay maintains a mirror of repo data from all accounts on the upstream PDS instances, and verifies repo data structure integrity and identity signatures. It is agnostic to applications, and does not validate data against atproto Lexicon schemas. + +This Relay implementation is designed to subscribe to the entire global network. The current state of the codebase is informally expected to scale to around 20 million accounts in the network, and thousands of repo events per second (peak). + +Features and design decisions: + +- runs on a single server +- repo data: stored on-disk in individual CAR "slice" files, with metadata in SQL. filesystem must accommodate tens of millions of small files +- firehose backfill data: stored on-disk by default, with metadata in SQL +- crawling and account state: stored in SQL database +- SQL driver: gorm, with PostgreSQL in production and sqlite for testing +- disk I/O intensive: fast NVMe disks are recommended, and RAM is helpful for caching +- highly concurrent: not particularly CPU intensive +- single golang binary for easy deployment +- observability: logging, prometheus metrics, OTEL traces +- "spidering" feature to auto-discover new accounts (DIDs) +- ability to export/import lists of DIDs to "backfill" Relay instances +- periodic repo compaction +- admin web interface: configure limits, add upstream PDS instances, etc + +This software is not as packaged, documented, and supported for self-hosting as our PDS distribution or Ozone service. But it is relatively simple and inexpensive to get running. + +A note and reminder about Relays in general is that they are more of a convenience in the protocol than a hard requirement. The "firehose" API is the exact same on the PDS and on a Relay. Any service which subscribes to the Relay could instead connect to one or more PDS instances directly.
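+
+To make that concrete, here is a minimal sketch (not part of this repository) of a downstream consumer: it dials the firehose WebSocket with `gorilla/websocket` and reads raw event frames. The hostname is a placeholder and frame decoding is left out; the same code works whether it points at a Relay or directly at a PDS, since both serve `/xrpc/com.atproto.sync.subscribeRepos`.
+
+```go
+package main
+
+import (
+	"log"
+
+	"github.com/gorilla/websocket"
+)
+
+func main() {
+	// Placeholder host: point this at a Relay or directly at a PDS.
+	u := "wss://relay.example.com/xrpc/com.atproto.sync.subscribeRepos"
+
+	conn, _, err := websocket.DefaultDialer.Dial(u, nil)
+	if err != nil {
+		log.Fatalf("dial failed: %v", err)
+	}
+	defer conn.Close()
+
+	for {
+		// Each firehose message is one binary WebSocket frame (a CBOR header and body).
+		_, msg, err := conn.ReadMessage()
+		if err != nil {
+			log.Fatalf("read failed: %v", err)
+		}
+		log.Printf("received event frame: %d bytes", len(msg))
+	}
+}
+```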
+ + +## Development Tips + +The README and Makefile at the top level of this git repo have some generic helpers for testing, linting, formatting code, etc. + +To re-build and run the Relay locally: + + make run-dev-relay + +You can re-build and run the command directly to get a list of configuration flags and env vars; env vars will be loaded from `.env` if that file exists: + + RELAY_ADMIN_KEY=localdev go run ./cmd/relay/ --help + +By default, the daemon will use sqlite for databases (in the directory `./data/bigsky/`), CAR data will be stored as individual shard files in `./data/bigsky/carstore/`, and the HTTP API will be bound to localhost port 2470. + +When the daemon isn't running, sqlite database files can be inspected with: + + sqlite3 data/bigsky/bgs.sqlite + [...] + sqlite> .schema + +Wipe all local data: + + # careful! double-check this destructive command + rm -rf ./data/bigsky/* + +There is a basic web dashboard, though it will not be included unless built and copied to a local directory `./public/`. Run `make build-relay-ui`, and then when running the daemon the dashboard will be available at: http://localhost:2470/dash. Paste in the admin key, eg `localdev`. + +The local admin routes can also be accessed by passing the admin key as a bearer token, for example: + + http get :2470/admin/pds/list Authorization:"Bearer localdev" + +Request a crawl of an individual PDS instance like: + + http post :2470/admin/pds/requestCrawl Authorization:"Bearer localdev" hostname=pds.example.com + + +## Docker Containers + +One way to deploy is running a docker image. You can pull and/or run a specific version of bigsky, referenced by git commit, from the Bluesky GitHub container registry. For example: + + docker pull ghcr.io/bluesky-social/indigo:bigsky-fd66f93ce1412a3678a1dd3e6d53320b725978a6 + docker run ghcr.io/bluesky-social/indigo:bigsky-fd66f93ce1412a3678a1dd3e6d53320b725978a6 + +There is a Dockerfile in this directory, which can be used to build customized/patched versions of the Relay as a container, republish them, run locally, deploy to servers, deploy to an orchestrated cluster, etc. See docs and guides for docker and cluster management systems for details. + + +## Database Setup + +PostgreSQL and SQLite are both supported. When using SQLite, separate files are used for Relay metadata and CarStore metadata. With PostgreSQL, a single database server, user, and logical database can all be reused: table names will not conflict. + +Database configuration is passed via the `DATABASE_URL` and `CARSTORE_DATABASE_URL` environment variables, or the corresponding CLI args. + +For PostgreSQL, the user and database must already be configured. Some example SQL commands are: + + CREATE DATABASE bgs; + CREATE DATABASE carstore; + + CREATE USER ${username} WITH PASSWORD '${password}'; + GRANT ALL PRIVILEGES ON DATABASE bgs TO ${username}; + GRANT ALL PRIVILEGES ON DATABASE carstore TO ${username}; + +This service currently uses `gorm` to automatically run database migrations as the regular user. There is no concept of running a separate set of migrations under a more privileged database user. + + +## Deployment + +*NOTE: this is not a complete guide to operating a Relay. There are decisions to be made and communicated about policies, bandwidth use, PDS crawling and rate-limits, financial sustainability, etc, which are not covered here.
This is just a quick overview of how to technically get a relay up and running.* + +In a real-world system, you will probably want to use PostgreSQL for both the relay database and the carstore database. CAR shards will still be stored on-disk, resulting in many millions of files. Choose your storage hardware and filesystem carefully: we recommend XFS on local NVMe, not network-backed block storage (eg, not EBS volumes on AWS). + +Some notable configuration env vars to set: + +- `ENVIRONMENT`: eg, `production` +- `DATABASE_URL`: see section below +- `CARSTORE_DATABASE_URL`: see section below +- `DATA_DIR`: CAR shards will be stored in a subdirectory +- `GOLOG_LOG_LEVEL`: log verbosity +- `RESOLVE_ADDRESS`: DNS server to use +- `FORCE_DNS_UDP`: recommend "true" +- `BGS_COMPACT_INTERVAL`: to control CAR compaction scheduling. For example, "8h" (every 8 hours). Set to "0" to disable automatic compaction. +- `MAX_CARSTORE_CONNECTIONS` and `MAX_METADB_CONNECTIONS`: number of concurrent SQL database connections +- `MAX_FETCH_CONCURRENCY`: how many outbound CAR backfill requests to make in parallel + +There is a health check endpoint at `/xrpc/_health`. Prometheus metrics are exposed by default on port 2471, path `/metrics`. The service logs fairly verbosely to stderr; use `GOLOG_LOG_LEVEL` to control log volume. + +As a rough guideline for the compute resources needed to run a full-network Relay, in June 2024 an example Relay for over 5 million repositories used: + +- around 30 million inodes (files) +- roughly 1 TByte of disk for PostgreSQL +- roughly 1 TByte of disk for CAR shard storage +- roughly 5k disk I/O operations per second (all combined) +- roughly 100% of one CPU core (quite low CPU utilization) +- roughly 5GB of RAM for bigsky, and as much RAM as available for PostgreSQL and page cache +- on the order of 1 megabit inbound bandwidth (crawling PDS instances) and 1 megabit outbound per connected client. 1 mbit continuous is approximately 350 GByte/month + +Be sure to double-check bandwidth usage and pricing if running a public relay! Bandwidth prices can vary widely between providers, and popular cloud services (AWS, Google Cloud, Azure) are very expensive compared to alternatives like OVH or Hetzner. + + +## Bootstrapping the Network + +To bootstrap the entire network, you'll want to start with a list of large PDS instances to backfill from. You could pull from a public dashboard of instances (like [mackuba's](https://blue.mackuba.eu/directory/pdses)), or scrape the full DID PLC directory, parse out all PDS service declarations, and sort by count. + +Once you have a set of PDS hosts, you can put the bare hostnames (not URLs: no `https://` prefix, port, or path suffix) in a `hosts.txt` file, and then use the `crawl_pds.sh` script to backfill and configure limits for all of them: + + export RELAY_HOST=your.pds.hostname.tld + export RELAY_ADMIN_KEY=your-secret-key + + # both request crawl, and set generous crawl limits for each + cat hosts.txt | parallel -j1 ./crawl_pds.sh {} + +Just consuming from the firehose for a few hours will only backfill accounts with activity during that period. This is fine to get the backfill process started, but eventually you'll want to do a full "resync" of all the repositories on the PDS host to the most recent repo rev version.
To enqueue that for all the PDS instances: + + # start sync/backfill of all accounts + cat hosts.txt | parallel -j1 ./sync_pds.sh {} + +Lastly, you can monitor the progress of any ongoing re-syncs: + + # check sync progress for all hosts + cat hosts.txt | parallel -j1 ./sync_pds.sh {} + + +## Admin API + +The relay has a number of admin HTTP API endpoints. Given a relay setup listening on port 2470 and with a reasonably secure admin secret: + +``` +RELAY_ADMIN_PASSWORD=$(openssl rand --hex 16) +relay --api-listen :2470 --admin-key ${RELAY_ADMIN_PASSWORD} ... +``` + +One can, for example, begin compaction of all repos: + +``` +curl -H 'Authorization: Bearer '${RELAY_ADMIN_PASSWORD} -H 'Content-Type: application/x-www-form-urlencoded' --data '' http://127.0.0.1:2470/admin/repo/compactAll +``` + +### /admin/subs/getUpstreamConns + +Return a list of PDS host names as a JSON array of strings: ["host", ...] + +### /admin/subs/perDayLimit + +Return `{"limit": int}` for the number of new PDS subscriptions that the relay may start in a rolling 24 hour window. + +### /admin/subs/setPerDayLimit + +POST with `?limit={int}` to set the number of new PDS subscriptions that the relay may start in a rolling 24 hour window. + +### /admin/subs/setEnabled + +POST with param `?enabled=true` or `?enabled=false` to enable or disable PDS-requested new-PDS crawling. + +### /admin/subs/getEnabled + +Return `{"enabled": bool}` indicating whether non-admin new PDS crawl requests are enabled. + +### /admin/subs/killUpstream + +POST with `?host={pds host name}` to disconnect from their firehose. + +Optionally add `&block=true` to prevent connecting to them in the future. + +### /admin/subs/listDomainBans + +Return `{"banned_domains": ["host name", ...]}` + +### /admin/subs/banDomain + +POST `{"Domain": "host name"}` to ban a domain + +### /admin/subs/unbanDomain + +POST `{"Domain": "host name"}` to un-ban a domain + +### /admin/repo/takeDown + +POST `{"did": "did:..."}` to take-down a bad repo; deletes all local data for the repo + +### /admin/repo/reverseTakedown + +POST `?did={did:...}` to reverse a repo take-down + +### /admin/repo/compact + +POST `?did={did:...}` to compact a repo. Optionally `&fast=true`. HTTP blocks until the compaction finishes. + +### /admin/repo/compactAll + +POST to begin compaction of all repos. Optional query params: + + * `fast=true` + * `limit={int}` maximum number of repos to compact (biggest first) (default 50) + * `threhsold={int}` minimum number of shard files a repo must have on disk to merit compaction (default 20) + +### /admin/repo/reset + +POST `?did={did:...}` deletes all local data for the repo + +### /admin/repo/verify + +POST `?did={did:...}` checks that all repo data is accessible. HTTP blocks until done. + +### /admin/pds/requestCrawl + +POST `{"hostname":"pds host"}` to start crawling a PDS + +### /admin/pds/list + +GET returns JSON list of records +```json +[{ + "Host": string, + "Did": string, + "SSL": bool, + "Cursor": int, + "Registered": bool, + "Blocked": bool, + "RateLimit": float, + "CrawlRateLimit": float, + "RepoCount": int, + "RepoLimit": int, + "HourlyEventLimit": int, + "DailyEventLimit": int, + + "HasActiveConnection": bool, + "EventsSeenSinceStartup": int, + "PerSecondEventRate": {"Max": float, "Window": float seconds}, + "PerHourEventRate": {"Max": float, "Window": float seconds}, + "PerDayEventRate": {"Max": float, "Window": float seconds}, + "CrawlRate": {"Max": float, "Window": float seconds}, + "UserCount": int, +}, ...]
+``` + +### /admin/pds/resync + +POST `?host={host}` to start a resync of a PDS + +GET `?host={host}` to get status of a PDS resync, return + +```json +{"resync": { + "pds": { + "Host": string, + "Did": string, + "SSL": bool, + "Cursor": int, + "Registered": bool, + "Blocked": bool, + "RateLimit": float, + "CrawlRateLimit": float, + "RepoCount": int, + "RepoLimit": int, + "HourlyEventLimit": int, + "DailyEventLimit": int, + }, + "numRepoPages": int, + "numRepos": int, + "numReposChecked": int, + "numReposToResync": int, + "status": string, + "statusChangedAt": time, +}} +``` + +### /admin/pds/changeLimits + +POST to set the limits for a PDS. body: + +```json +{ + "host": string, + "per_second": int, + "per_hour": int, + "per_day": int, + "crawl_rate": int, + "repo_limit": int, +} +``` + +### /admin/pds/block + +POST `?host={host}` to block a PDS + +### /admin/pds/unblock + +POST `?host={host}` to un-block a PDS + + +### /admin/pds/addTrustedDomain + +POST `?domain={}` to make a domain trusted + +### /admin/consumers/list + +GET returns list json of clients currently reading from the relay firehose + +```json +[{ + "id": int, + "remote_addr": string, + "user_agent": string, + "events_consumed": int, + "connected_at": time, +}, ...] +``` diff --git a/cmd/relay/bgs/admin.go b/cmd/relay/bgs/admin.go new file mode 100644 index 000000000..068d181c4 --- /dev/null +++ b/cmd/relay/bgs/admin.go @@ -0,0 +1,539 @@ +package bgs + +import ( + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "strconv" + "strings" + "time" + + "github.com/bluesky-social/indigo/cmd/relay/models" + "github.com/labstack/echo/v4" + dto "github.com/prometheus/client_model/go" + "gorm.io/gorm" +) + +func (bgs *BGS) handleAdminSetSubsEnabled(e echo.Context) error { + enabled, err := strconv.ParseBool(e.QueryParam("enabled")) + if err != nil { + return &echo.HTTPError{ + Code: 400, + Message: err.Error(), + } + } + + return bgs.slurper.SetNewSubsDisabled(!enabled) +} + +func (bgs *BGS) handleAdminGetSubsEnabled(e echo.Context) error { + return e.JSON(200, map[string]bool{ + "enabled": !bgs.slurper.GetNewSubsDisabledState(), + }) +} + +func (bgs *BGS) handleAdminGetNewPDSPerDayRateLimit(e echo.Context) error { + limit := bgs.slurper.GetNewPDSPerDayLimit() + return e.JSON(200, map[string]int64{ + "limit": limit, + }) +} + +func (bgs *BGS) handleAdminSetNewPDSPerDayRateLimit(e echo.Context) error { + limit, err := strconv.ParseInt(e.QueryParam("limit"), 10, 64) + if err != nil { + return &echo.HTTPError{ + Code: 400, + Message: fmt.Errorf("failed to parse limit: %w", err).Error(), + } + } + + err = bgs.slurper.SetNewPDSPerDayLimit(limit) + if err != nil { + return &echo.HTTPError{ + Code: 500, + Message: fmt.Errorf("failed to set new PDS per day rate limit: %w", err).Error(), + } + } + + return nil +} + +func (bgs *BGS) handleAdminTakeDownRepo(e echo.Context) error { + ctx := e.Request().Context() + + var body map[string]string + if err := e.Bind(&body); err != nil { + return err + } + did, ok := body["did"] + if !ok { + return &echo.HTTPError{ + Code: 400, + Message: "must specify did parameter in body", + } + } + + err := bgs.TakeDownRepo(ctx, did) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return &echo.HTTPError{ + Code: http.StatusNotFound, + Message: "repo not found", + } + } + return &echo.HTTPError{ + Code: http.StatusInternalServerError, + Message: err.Error(), + } + } + return nil +} + +func (bgs *BGS) handleAdminReverseTakedown(e echo.Context) error { + did := e.QueryParam("did") + ctx 
:= e.Request().Context() + err := bgs.ReverseTakedown(ctx, did) + + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return &echo.HTTPError{ + Code: http.StatusNotFound, + Message: "repo not found", + } + } + return &echo.HTTPError{ + Code: http.StatusInternalServerError, + Message: err.Error(), + } + } + + return nil +} + +type ListTakedownsResponse struct { + Dids []string `json:"dids"` + Cursor int64 `json:"cursor,omitempty"` +} + +func (bgs *BGS) handleAdminListRepoTakeDowns(e echo.Context) error { + ctx := e.Request().Context() + haveMinId := false + minId := int64(-1) + qmin := e.QueryParam("cursor") + if qmin != "" { + tmin, err := strconv.ParseInt(qmin, 10, 64) + if err != nil { + return &echo.HTTPError{Code: 400, Message: "bad cursor"} + } + minId = tmin + haveMinId = true + } + limit := 1000 + wat := bgs.db.Model(User{}).WithContext(ctx).Select("id", "did").Where("taken_down = TRUE") + if haveMinId { + wat = wat.Where("id > ?", minId) + } + //var users []User + rows, err := wat.Order("id").Limit(limit).Rows() + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "oops").WithInternal(err) + } + var out ListTakedownsResponse + for rows.Next() { + var id int64 + var did string + err := rows.Scan(&id, &did) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "oops").WithInternal(err) + } + out.Dids = append(out.Dids, did) + out.Cursor = id + } + if len(out.Dids) < limit { + out.Cursor = 0 + } + return e.JSON(200, out) +} + +func (bgs *BGS) handleAdminGetUpstreamConns(e echo.Context) error { + return e.JSON(200, bgs.slurper.GetActiveList()) +} + +type rateLimit struct { + Max float64 `json:"Max"` + WindowSeconds float64 `json:"Window"` +} + +type enrichedPDS struct { + models.PDS + HasActiveConnection bool `json:"HasActiveConnection"` + EventsSeenSinceStartup uint64 `json:"EventsSeenSinceStartup"` + PerSecondEventRate rateLimit `json:"PerSecondEventRate"` + PerHourEventRate rateLimit `json:"PerHourEventRate"` + PerDayEventRate rateLimit `json:"PerDayEventRate"` + UserCount int64 `json:"UserCount"` +} + +type UserCount struct { + PDSID uint `gorm:"column:pds"` + UserCount int64 `gorm:"column:user_count"` +} + +func (bgs *BGS) handleListPDSs(e echo.Context) error { + var pds []models.PDS + if err := bgs.db.Find(&pds).Error; err != nil { + return err + } + + enrichedPDSs := make([]enrichedPDS, len(pds)) + + activePDSHosts := bgs.slurper.GetActiveList() + + for i, p := range pds { + enrichedPDSs[i].PDS = p + enrichedPDSs[i].HasActiveConnection = false + for _, host := range activePDSHosts { + if strings.ToLower(host) == strings.ToLower(p.Host) { + enrichedPDSs[i].HasActiveConnection = true + break + } + } + var m = &dto.Metric{} + if err := eventsReceivedCounter.WithLabelValues(p.Host).Write(m); err != nil { + enrichedPDSs[i].EventsSeenSinceStartup = 0 + continue + } + enrichedPDSs[i].EventsSeenSinceStartup = uint64(m.Counter.GetValue()) + + enrichedPDSs[i].PerSecondEventRate = rateLimit{ + Max: p.RateLimit, + WindowSeconds: 1, + } + + enrichedPDSs[i].PerHourEventRate = rateLimit{ + Max: float64(p.HourlyEventLimit), + WindowSeconds: 3600, + } + + enrichedPDSs[i].PerDayEventRate = rateLimit{ + Max: float64(p.DailyEventLimit), + WindowSeconds: 86400, + } + } + + return e.JSON(200, enrichedPDSs) +} + +type consumer struct { + ID uint64 `json:"id"` + RemoteAddr string `json:"remote_addr"` + UserAgent string `json:"user_agent"` + EventsConsumed uint64 `json:"events_consumed"` + ConnectedAt time.Time `json:"connected_at"` +} + +func 
(bgs *BGS) handleAdminListConsumers(e echo.Context) error { + bgs.consumersLk.RLock() + defer bgs.consumersLk.RUnlock() + + consumers := make([]consumer, 0, len(bgs.consumers)) + for id, c := range bgs.consumers { + var m = &dto.Metric{} + if err := c.EventsSent.Write(m); err != nil { + continue + } + consumers = append(consumers, consumer{ + ID: id, + RemoteAddr: c.RemoteAddr, + UserAgent: c.UserAgent, + EventsConsumed: uint64(m.Counter.GetValue()), + ConnectedAt: c.ConnectedAt, + }) + } + + return e.JSON(200, consumers) +} + +func (bgs *BGS) handleAdminKillUpstreamConn(e echo.Context) error { + host := strings.TrimSpace(e.QueryParam("host")) + if host == "" { + return &echo.HTTPError{ + Code: 400, + Message: "must pass a valid host", + } + } + + block := strings.ToLower(e.QueryParam("block")) == "true" + + if err := bgs.slurper.KillUpstreamConnection(host, block); err != nil { + if errors.Is(err, ErrNoActiveConnection) { + return &echo.HTTPError{ + Code: 400, + Message: "no active connection to given host", + } + } + return err + } + + return e.JSON(200, map[string]any{ + "success": "true", + }) +} + +func (bgs *BGS) handleBlockPDS(e echo.Context) error { + host := strings.TrimSpace(e.QueryParam("host")) + if host == "" { + return &echo.HTTPError{ + Code: 400, + Message: "must pass a valid host", + } + } + + // Set the block flag to true in the DB + if err := bgs.db.Model(&models.PDS{}).Where("host = ?", host).Update("blocked", true).Error; err != nil { + return err + } + + // don't care if this errors, but we should try to disconnect something we just blocked + _ = bgs.slurper.KillUpstreamConnection(host, false) + + return e.JSON(200, map[string]any{ + "success": "true", + }) +} + +func (bgs *BGS) handleUnblockPDS(e echo.Context) error { + host := strings.TrimSpace(e.QueryParam("host")) + if host == "" { + return &echo.HTTPError{ + Code: 400, + Message: "must pass a valid host", + } + } + + // Set the block flag to false in the DB + if err := bgs.db.Model(&models.PDS{}).Where("host = ?", host).Update("blocked", false).Error; err != nil { + return err + } + + return e.JSON(200, map[string]any{ + "success": "true", + }) +} + +type bannedDomains struct { + BannedDomains []string `json:"banned_domains"` +} + +func (bgs *BGS) handleAdminListDomainBans(c echo.Context) error { + var all []DomainBan + if err := bgs.db.Find(&all).Error; err != nil { + return err + } + + resp := bannedDomains{ + BannedDomains: []string{}, + } + for _, b := range all { + resp.BannedDomains = append(resp.BannedDomains, b.Domain) + } + + return c.JSON(200, resp) +} + +type banDomainBody struct { + Domain string +} + +func (bgs *BGS) handleAdminBanDomain(c echo.Context) error { + var body banDomainBody + if err := c.Bind(&body); err != nil { + return err + } + + // Check if the domain is already banned + var existing DomainBan + if err := bgs.db.Where("domain = ?", body.Domain).First(&existing).Error; err == nil { + return &echo.HTTPError{ + Code: 400, + Message: "domain is already banned", + } + } + + if err := bgs.db.Create(&DomainBan{ + Domain: body.Domain, + }).Error; err != nil { + return err + } + + return c.JSON(200, map[string]any{ + "success": "true", + }) +} + +func (bgs *BGS) handleAdminUnbanDomain(c echo.Context) error { + var body banDomainBody + if err := c.Bind(&body); err != nil { + return err + } + + if err := bgs.db.Where("domain = ?", body.Domain).Delete(&DomainBan{}).Error; err != nil { + return err + } + + return c.JSON(200, map[string]any{ + "success": "true", + }) +} + +type PDSRates struct { + 
// core event rate, counts firehose events + PerSecond int64 `json:"per_second,omitempty"` + PerHour int64 `json:"per_hour,omitempty"` + PerDay int64 `json:"per_day,omitempty"` + + RepoLimit int64 `json:"repo_limit,omitempty"` +} + +func (pr *PDSRates) FromSlurper(s *Slurper) { + if pr.PerSecond == 0 { + pr.PerHour = s.DefaultPerSecondLimit + } + if pr.PerHour == 0 { + pr.PerHour = s.DefaultPerHourLimit + } + if pr.PerDay == 0 { + pr.PerDay = s.DefaultPerDayLimit + } + if pr.RepoLimit == 0 { + pr.RepoLimit = s.DefaultRepoLimit + } +} + +type RateLimitChangeRequest struct { + Host string `json:"host"` + PDSRates +} + +func (bgs *BGS) handleAdminChangePDSRateLimits(e echo.Context) error { + var body RateLimitChangeRequest + if err := e.Bind(&body); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid body: %s", err)) + } + + // Get the PDS from the DB + var pds models.PDS + if err := bgs.db.Where("host = ?", body.Host).First(&pds).Error; err != nil { + return err + } + + // Update the rate limits in the DB + pds.RateLimit = float64(body.PerSecond) + pds.HourlyEventLimit = body.PerHour + pds.DailyEventLimit = body.PerDay + pds.RepoLimit = body.RepoLimit + + if err := bgs.db.Save(&pds).Error; err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to save rate limit changes: %w", err)) + } + + // Update the rate limit in the limiter + limits := bgs.slurper.GetOrCreateLimiters(pds.ID, body.PerSecond, body.PerHour, body.PerDay) + limits.PerSecond.SetLimit(body.PerSecond) + limits.PerHour.SetLimit(body.PerHour) + limits.PerDay.SetLimit(body.PerDay) + + return e.JSON(200, map[string]any{ + "success": "true", + }) +} + +func (bgs *BGS) handleAdminAddTrustedDomain(e echo.Context) error { + domain := e.QueryParam("domain") + if domain == "" { + return fmt.Errorf("must specify domain in query parameter") + } + + // Check if the domain is already trusted + trustedDomains := bgs.slurper.GetTrustedDomains() + if slices.Contains(trustedDomains, domain) { + return &echo.HTTPError{ + Code: 400, + Message: "domain is already trusted", + } + } + + if err := bgs.slurper.AddTrustedDomain(domain); err != nil { + return err + } + + return e.JSON(200, map[string]any{ + "success": true, + }) +} + +type AdminRequestCrawlRequest struct { + Hostname string `json:"hostname"` + + // optional: + PDSRates +} + +func (bgs *BGS) handleAdminRequestCrawl(e echo.Context) error { + ctx := e.Request().Context() + + var body AdminRequestCrawlRequest + if err := e.Bind(&body); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid body: %s", err)) + } + + host := body.Hostname + if host == "" { + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname") + } + + if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") { + if bgs.ssl { + host = "https://" + host + } else { + host = "http://" + host + } + } + + u, err := url.Parse(host) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to parse hostname") + } + + if u.Scheme == "http" && bgs.ssl { + return echo.NewHTTPError(http.StatusBadRequest, "this server requires https") + } + + if u.Scheme == "https" && !bgs.ssl { + return echo.NewHTTPError(http.StatusBadRequest, "this server does not support https") + } + + if u.Path != "" { + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without path") + } + + if u.Query().Encode() != "" { + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without 
query") + } + + host = u.Host // potentially hostname:port + + banned, err := bgs.domainIsBanned(ctx, host) + if banned { + return echo.NewHTTPError(http.StatusUnauthorized, "domain is banned") + } + + // Skip checking if the server is online for now + rateOverrides := body.PDSRates + rateOverrides.FromSlurper(bgs.slurper) + + return bgs.slurper.SubscribeToPds(ctx, host, true, true, &rateOverrides) // Override Trusted Domain Check +} diff --git a/cmd/relay/bgs/bgs.go b/cmd/relay/bgs/bgs.go new file mode 100644 index 000000000..f94e6b53d --- /dev/null +++ b/cmd/relay/bgs/bgs.go @@ -0,0 +1,1376 @@ +package bgs + +import ( + "context" + "database/sql" + "errors" + "fmt" + "github.com/bluesky-social/indigo/atproto/identity" + "github.com/bluesky-social/indigo/atproto/syntax" + "github.com/ipfs/go-cid" + "io" + "log/slog" + "net" + "net/http" + _ "net/http/pprof" + "net/url" + "strconv" + "strings" + "sync" + "time" + + comatproto "github.com/bluesky-social/indigo/api/atproto" + "github.com/bluesky-social/indigo/cmd/relay/events" + "github.com/bluesky-social/indigo/cmd/relay/models" + "github.com/bluesky-social/indigo/cmd/relay/repomgr" + lexutil "github.com/bluesky-social/indigo/lex/util" + "github.com/bluesky-social/indigo/xrpc" + + "github.com/gorilla/websocket" + lru "github.com/hashicorp/golang-lru/v2" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + promclient "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" + dto "github.com/prometheus/client_model/go" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "gorm.io/gorm" +) + +var tracer = otel.Tracer("bgs") + +// serverListenerBootTimeout is how long to wait for the requested server socket +// to become available for use. This is an arbitrary timeout that should be safe +// on any platform, but there's no great way to weave this timeout without +// adding another parameter to the (at time of writing) long signature of +// NewServer. 
+const serverListenerBootTimeout = 5 * time.Second + +type BGS struct { + db *gorm.DB + slurper *Slurper + events *events.EventManager + didd identity.Directory + + // TODO: work on doing away with this flag in favor of more pluggable + // pieces that abstract the need for explicit ssl checks + ssl bool + + crawlOnly bool + + // TODO: at some point we will want to lock specific DIDs, this lock as is + // is overly broad, but i dont expect it to be a bottleneck for now + extUserLk sync.Mutex + + repoman *repomgr.RepoManager + + // Management of Socket Consumers + consumersLk sync.RWMutex + nextConsumerID uint64 + consumers map[uint64]*SocketConsumer + + // User cache + userCache *lru.Cache[string, *User] + + // nextCrawlers gets forwarded POST /xrpc/com.atproto.sync.requestCrawl + nextCrawlers []*url.URL + httpClient http.Client + + log *slog.Logger + inductionTraceLog *slog.Logger + + config BGSConfig +} + +type SocketConsumer struct { + UserAgent string + RemoteAddr string + ConnectedAt time.Time + EventsSent promclient.Counter +} + +type BGSConfig struct { + SSL bool + DefaultRepoLimit int64 + ConcurrencyPerPDS int64 + MaxQueuePerPDS int64 + + // NextCrawlers gets forwarded POST /xrpc/com.atproto.sync.requestCrawl + NextCrawlers []*url.URL + + ApplyPDSClientSettings func(c *xrpc.Client) + InductionTraceLog *slog.Logger +} + +func DefaultBGSConfig() *BGSConfig { + return &BGSConfig{ + SSL: true, + DefaultRepoLimit: 100, + ConcurrencyPerPDS: 100, + MaxQueuePerPDS: 1_000, + } +} + +func NewBGS(db *gorm.DB, repoman *repomgr.RepoManager, evtman *events.EventManager, didd identity.Directory, config *BGSConfig) (*BGS, error) { + + if config == nil { + config = DefaultBGSConfig() + } + if err := db.AutoMigrate(AuthToken{}); err != nil { + panic(err) + } + if err := db.AutoMigrate(DomainBan{}); err != nil { + panic(err) + } + if err := db.AutoMigrate(models.PDS{}); err != nil { + panic(err) + } + if err := db.AutoMigrate(User{}); err != nil { + panic(err) + } + if err := db.AutoMigrate(UserPreviousState{}); err != nil { + panic(err) + } + + uc, _ := lru.New[string, *User](1_000_000) + + bgs := &BGS{ + db: db, + + repoman: repoman, + events: evtman, + didd: didd, + ssl: config.SSL, + + consumersLk: sync.RWMutex{}, + consumers: make(map[uint64]*SocketConsumer), + + userCache: uc, + + log: slog.Default().With("system", "bgs"), + + config: *config, + + inductionTraceLog: config.InductionTraceLog, + } + + slOpts := DefaultSlurperOptions() + slOpts.SSL = config.SSL + slOpts.DefaultRepoLimit = config.DefaultRepoLimit + slOpts.ConcurrencyPerPDS = config.ConcurrencyPerPDS + slOpts.MaxQueuePerPDS = config.MaxQueuePerPDS + slOpts.Logger = bgs.log + s, err := NewSlurper(db, bgs.handleFedEvent, slOpts) + if err != nil { + return nil, err + } + + bgs.slurper = s + + if err := bgs.slurper.RestartAll(); err != nil { + return nil, err + } + + bgs.nextCrawlers = config.NextCrawlers + bgs.httpClient.Timeout = time.Second * 5 + + return bgs, nil +} + +func (bgs *BGS) StartMetrics(listen string) error { + http.Handle("/metrics", promhttp.Handler()) + return http.ListenAndServe(listen, nil) +} + +func (bgs *BGS) Start(addr string, logWriter io.Writer) error { + var lc net.ListenConfig + ctx, cancel := context.WithTimeout(context.Background(), serverListenerBootTimeout) + defer cancel() + + li, err := lc.Listen(ctx, "tcp", addr) + if err != nil { + return err + } + return bgs.StartWithListener(li, logWriter) +} + +func (bgs *BGS) StartWithListener(listen net.Listener, logWriter io.Writer) error { + e := echo.New() + 
e.Logger.SetOutput(logWriter) + e.HideBanner = true + + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ + AllowOrigins: []string{"*"}, + AllowHeaders: []string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, echo.HeaderAuthorization}, + })) + + if !bgs.ssl { + e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ + Format: "method=${method}, uri=${uri}, status=${status} latency=${latency_human}\n", + })) + } else { + e.Use(middleware.LoggerWithConfig(middleware.DefaultLoggerConfig)) + } + + // React uses a virtual router, so we need to serve the index.html for all + // routes that aren't otherwise handled or in the /assets directory. + e.File("/dash", "public/index.html") + e.File("/dash/*", "public/index.html") + e.Static("/assets", "public/assets") + + e.Use(MetricsMiddleware) + + e.HTTPErrorHandler = func(err error, ctx echo.Context) { + switch err := err.(type) { + case *echo.HTTPError: + if err2 := ctx.JSON(err.Code, map[string]any{ + "error": err.Message, + }); err2 != nil { + bgs.log.Error("Failed to write http error", "err", err2) + } + default: + sendHeader := true + if ctx.Path() == "/xrpc/com.atproto.sync.subscribeRepos" { + sendHeader = false + } + + bgs.log.Warn("HANDLER ERROR: (%s) %s", ctx.Path(), err) + + if strings.HasPrefix(ctx.Path(), "/admin/") { + ctx.JSON(500, map[string]any{ + "error": err.Error(), + }) + return + } + + if sendHeader { + ctx.Response().WriteHeader(500) + } + } + } + + // TODO: this API is temporary until we formalize what we want here + + e.GET("/xrpc/com.atproto.sync.subscribeRepos", bgs.EventsHandler) + e.POST("/xrpc/com.atproto.sync.requestCrawl", bgs.HandleComAtprotoSyncRequestCrawl) + e.GET("/xrpc/com.atproto.sync.listRepos", bgs.HandleComAtprotoSyncListRepos) + e.GET("/xrpc/com.atproto.sync.getRepo", bgs.HandleComAtprotoSyncGetRepo) // just returns 3xx redirect to source PDS + e.GET("/xrpc/com.atproto.sync.getLatestCommit", bgs.HandleComAtprotoSyncGetLatestCommit) + e.GET("/xrpc/_health", bgs.HandleHealthCheck) + e.GET("/_health", bgs.HandleHealthCheck) + e.GET("/", bgs.HandleHomeMessage) + + admin := e.Group("/admin", bgs.checkAdminAuth) + + // Slurper-related Admin API + admin.GET("/subs/getUpstreamConns", bgs.handleAdminGetUpstreamConns) + admin.GET("/subs/getEnabled", bgs.handleAdminGetSubsEnabled) + admin.GET("/subs/perDayLimit", bgs.handleAdminGetNewPDSPerDayRateLimit) + admin.POST("/subs/setEnabled", bgs.handleAdminSetSubsEnabled) + admin.POST("/subs/killUpstream", bgs.handleAdminKillUpstreamConn) + admin.POST("/subs/setPerDayLimit", bgs.handleAdminSetNewPDSPerDayRateLimit) + + // Domain-related Admin API + admin.GET("/subs/listDomainBans", bgs.handleAdminListDomainBans) + admin.POST("/subs/banDomain", bgs.handleAdminBanDomain) + admin.POST("/subs/unbanDomain", bgs.handleAdminUnbanDomain) + + // Repo-related Admin API + admin.POST("/repo/takeDown", bgs.handleAdminTakeDownRepo) + admin.POST("/repo/reverseTakedown", bgs.handleAdminReverseTakedown) + admin.GET("/repo/takedowns", bgs.handleAdminListRepoTakeDowns) + + // PDS-related Admin API + admin.POST("/pds/requestCrawl", bgs.handleAdminRequestCrawl) + admin.GET("/pds/list", bgs.handleListPDSs) + admin.POST("/pds/changeLimits", bgs.handleAdminChangePDSRateLimits) + admin.POST("/pds/block", bgs.handleBlockPDS) + admin.POST("/pds/unblock", bgs.handleUnblockPDS) + admin.POST("/pds/addTrustedDomain", bgs.handleAdminAddTrustedDomain) + + // Consumer-related Admin API + admin.GET("/consumers/list", bgs.handleAdminListConsumers) + + // In order to support booting on 
random ports in tests, we need to tell the + // Echo instance it's already got a port, and then use its StartServer + // method to re-use that listener. + e.Listener = listen + srv := &http.Server{} + return e.StartServer(srv) +} + +func (bgs *BGS) Shutdown() []error { + errs := bgs.slurper.Shutdown() + + if err := bgs.events.Shutdown(context.TODO()); err != nil { + errs = append(errs, err) + } + + return errs +} + +type HealthStatus struct { + Status string `json:"status"` + Message string `json:"msg,omitempty"` +} + +func (bgs *BGS) HandleHealthCheck(c echo.Context) error { + if err := bgs.db.Exec("SELECT 1").Error; err != nil { + bgs.log.Error("healthcheck can't connect to database", "err", err) + return c.JSON(500, HealthStatus{Status: "error", Message: "can't connect to database"}) + } else { + return c.JSON(200, HealthStatus{Status: "ok"}) + } +} + +var homeMessage string = ` +d8888b. d888888b d888b .d8888. db dD db db +88 '8D '88' 88' Y8b 88' YP 88 ,8P' '8b d8' +88oooY' 88 88 '8bo. 88,8P '8bd8' +88~~~b. 88 88 ooo 'Y8b. 88'8b 88 +88 8D .88. 88. ~8~ db 8D 88 '88. 88 +Y8888P' Y888888P Y888P '8888Y' YP YD YP + +This is an atproto [https://atproto.com] relay instance, running the 'bigsky' codebase [https://github.com/bluesky-social/indigo] + +The firehose WebSocket path is at: /xrpc/com.atproto.sync.subscribeRepos +` + +func (bgs *BGS) HandleHomeMessage(c echo.Context) error { + return c.String(http.StatusOK, homeMessage) +} + +type AuthToken struct { + gorm.Model + Token string `gorm:"index"` +} + +func (bgs *BGS) lookupAdminToken(tok string) (bool, error) { + var at AuthToken + if err := bgs.db.Find(&at, "token = ?", tok).Error; err != nil { + return false, err + } + + if at.ID == 0 { + return false, nil + } + + return true, nil +} + +func (bgs *BGS) CreateAdminToken(tok string) error { + exists, err := bgs.lookupAdminToken(tok) + if err != nil { + return err + } + + if exists { + return nil + } + + return bgs.db.Create(&AuthToken{ + Token: tok, + }).Error +} + +func (bgs *BGS) checkAdminAuth(next echo.HandlerFunc) echo.HandlerFunc { + return func(e echo.Context) error { + ctx, span := tracer.Start(e.Request().Context(), "checkAdminAuth") + defer span.End() + + e.SetRequest(e.Request().WithContext(ctx)) + + authheader := e.Request().Header.Get("Authorization") + pref := "Bearer " + if !strings.HasPrefix(authheader, pref) { + return echo.ErrForbidden + } + + token := authheader[len(pref):] + + exists, err := bgs.lookupAdminToken(token) + if err != nil { + return err + } + + if !exists { + return echo.ErrForbidden + } + + return next(e) + } +} + +type User struct { + ID models.Uid `gorm:"primarykey;index:idx_user_id_active,where:taken_down = false AND tombstoned = false"` + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt gorm.DeletedAt `gorm:"index"` + Handle sql.NullString `gorm:"index"` + Did string `gorm:"uniqueIndex"` + PDS uint + ValidHandle bool `gorm:"default:true"` + + // TakenDown is set to true if the user in question has been taken down. + // A user in this state will have all future events related to it dropped + // and no data about this user will be served. 
+ TakenDown bool + Tombstoned bool + + // UpstreamStatus is the state of the user as reported by the upstream PDS + UpstreamStatus string `gorm:"index"` + + lk sync.Mutex +} + +func (u *User) GetDid() string { + return u.Did +} + +func (u *User) GetUid() models.Uid { + return u.ID +} + +func (u *User) SetTakenDown(v bool) { + u.lk.Lock() + defer u.lk.Unlock() + u.TakenDown = v +} + +func (u *User) GetTakenDown() bool { + u.lk.Lock() + defer u.lk.Unlock() + return u.TakenDown +} + +func (u *User) SetTombstoned(v bool) { + u.lk.Lock() + defer u.lk.Unlock() + u.Tombstoned = v +} + +func (u *User) GetTombstoned() bool { + u.lk.Lock() + defer u.lk.Unlock() + return u.Tombstoned +} + +func (u *User) SetUpstreamStatus(v string) { + u.lk.Lock() + defer u.lk.Unlock() + u.UpstreamStatus = v +} + +func (u *User) GetUpstreamStatus() string { + u.lk.Lock() + defer u.lk.Unlock() + return u.UpstreamStatus +} + +type UserPreviousState struct { + Uid models.Uid `gorm:"column:uid;primaryKey"` + Cid models.DbCID `gorm:"column:cid"` + Rev string `gorm:"column:rev"` + Seq int64 `gorm:"column:seq"` +} + +func (ups *UserPreviousState) GetCid() cid.Cid { + return ups.Cid.CID +} +func (ups *UserPreviousState) GetRev() syntax.TID { + xt, _ := syntax.ParseTID(ups.Rev) + return xt +} + +type addTargetBody struct { + Host string `json:"host"` +} + +func (bgs *BGS) registerConsumer(c *SocketConsumer) uint64 { + bgs.consumersLk.Lock() + defer bgs.consumersLk.Unlock() + + id := bgs.nextConsumerID + bgs.nextConsumerID++ + + bgs.consumers[id] = c + + return id +} + +func (bgs *BGS) cleanupConsumer(id uint64) { + bgs.consumersLk.Lock() + defer bgs.consumersLk.Unlock() + + c := bgs.consumers[id] + + var m = &dto.Metric{} + if err := c.EventsSent.Write(m); err != nil { + bgs.log.Error("failed to get sent counter", "err", err) + } + + bgs.log.Info("consumer disconnected", + "consumer_id", id, + "remote_addr", c.RemoteAddr, + "user_agent", c.UserAgent, + "events_sent", m.Counter.GetValue()) + + delete(bgs.consumers, id) +} + +// GET+websocket /xrpc/com.atproto.sync.subscribeRepos +func (bgs *BGS) EventsHandler(c echo.Context) error { + var since *int64 + if sinceVal := c.QueryParam("cursor"); sinceVal != "" { + sval, err := strconv.ParseInt(sinceVal, 10, 64) + if err != nil { + return err + } + since = &sval + } + + ctx, cancel := context.WithCancel(c.Request().Context()) + defer cancel() + + // TODO: authhhh + conn, err := websocket.Upgrade(c.Response(), c.Request(), c.Response().Header(), 10<<10, 10<<10) + if err != nil { + return fmt.Errorf("upgrading websocket: %w", err) + } + + defer conn.Close() + + lastWriteLk := sync.Mutex{} + lastWrite := time.Now() + + // Start a goroutine to ping the client every 30 seconds to check if it's + // still alive. If the client doesn't respond to a ping within 5 seconds, + // we'll close the connection and teardown the consumer. 
+ go func() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + lastWriteLk.Lock() + lw := lastWrite + lastWriteLk.Unlock() + + if time.Since(lw) < 30*time.Second { + continue + } + + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil { + bgs.log.Warn("failed to ping client", "err", err) + cancel() + return + } + case <-ctx.Done(): + return + } + } + }() + + conn.SetPingHandler(func(message string) error { + err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60)) + if err == websocket.ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + }) + + // Start a goroutine to read messages from the client and discard them. + go func() { + for { + _, _, err := conn.ReadMessage() + if err != nil { + bgs.log.Warn("failed to read message from client", "err", err) + cancel() + return + } + } + }() + + ident := c.RealIP() + "-" + c.Request().UserAgent() + + evts, cleanup, err := bgs.events.Subscribe(ctx, ident, func(evt *events.XRPCStreamEvent) bool { return true }, since) + if err != nil { + return err + } + defer cleanup() + + // Keep track of the consumer for metrics and admin endpoints + consumer := SocketConsumer{ + RemoteAddr: c.RealIP(), + UserAgent: c.Request().UserAgent(), + ConnectedAt: time.Now(), + } + sentCounter := eventsSentCounter.WithLabelValues(consumer.RemoteAddr, consumer.UserAgent) + consumer.EventsSent = sentCounter + + consumerID := bgs.registerConsumer(&consumer) + defer bgs.cleanupConsumer(consumerID) + + logger := bgs.log.With( + "consumer_id", consumerID, + "remote_addr", consumer.RemoteAddr, + "user_agent", consumer.UserAgent, + ) + + logger.Info("new consumer", "cursor", since) + + for { + select { + case evt, ok := <-evts: + if !ok { + logger.Error("event stream closed unexpectedly") + return nil + } + + wc, err := conn.NextWriter(websocket.BinaryMessage) + if err != nil { + logger.Error("failed to get next writer", "err", err) + return err + } + + if evt.Preserialized != nil { + _, err = wc.Write(evt.Preserialized) + } else { + err = evt.Serialize(wc) + } + if err != nil { + return fmt.Errorf("failed to write event: %w", err) + } + + if err := wc.Close(); err != nil { + logger.Warn("failed to flush-close our event write", "err", err) + return nil + } + + lastWriteLk.Lock() + lastWrite = time.Now() + lastWriteLk.Unlock() + sentCounter.Inc() + case <-ctx.Done(): + return nil + } + } +} + +// domainIsBanned checks if the given host is banned, starting with the host +// itself, then checking every parent domain up to the tld +func (s *BGS) domainIsBanned(ctx context.Context, host string) (bool, error) { + // ignore ports when checking for ban status + hostport := strings.Split(host, ":") + + segments := strings.Split(hostport[0], ".") + + // TODO: use normalize method once that merges + var cleaned []string + for _, s := range segments { + if s == "" { + continue + } + s = strings.ToLower(s) + + cleaned = append(cleaned, s) + } + segments = cleaned + + for i := 0; i < len(segments)-1; i++ { + dchk := strings.Join(segments[i:], ".") + found, err := s.findDomainBan(ctx, dchk) + if err != nil { + return false, err + } + + if found { + return true, nil + } + } + return false, nil +} + +func (s *BGS) findDomainBan(ctx context.Context, host string) (bool, error) { + var db DomainBan + if err := s.db.Find(&db, "domain = ?", host).Error; err != nil { + return false, err 
+ } + + if db.ID == 0 { + return false, nil + } + + return true, nil +} + +var ErrNotFound = errors.New("not found") + +func (bgs *BGS) DidToUid(ctx context.Context, did string) (models.Uid, error) { + xu, err := bgs.lookupUserByDid(ctx, did) + if err != nil { + return 0, err + } + if xu == nil { + return 0, ErrNotFound + } + return xu.ID, nil +} + +func (bgs *BGS) lookupUserByDid(ctx context.Context, did string) (*User, error) { + ctx, span := tracer.Start(ctx, "lookupUserByDid") + defer span.End() + + cu, ok := bgs.userCache.Get(did) + if ok { + return cu, nil + } + + var u User + if err := bgs.db.Find(&u, "did = ?", did).Error; err != nil { + return nil, err + } + + if u.ID == 0 { + return nil, gorm.ErrRecordNotFound + } + + bgs.userCache.Add(did, &u) + + return &u, nil +} + +func (bgs *BGS) lookupUserByUID(ctx context.Context, uid models.Uid) (*User, error) { + ctx, span := tracer.Start(ctx, "lookupUserByUID") + defer span.End() + + var u User + if err := bgs.db.Find(&u, "id = ?", uid).Error; err != nil { + return nil, err + } + + if u.ID == 0 { + return nil, gorm.ErrRecordNotFound + } + + return &u, nil +} + +func stringLink(lnk *lexutil.LexLink) string { + if lnk == nil { + return "" + } + + return lnk.String() +} + +// handleFedEvent() is the callback passed to Slurper called from Slurper.handleConnection() +func (bgs *BGS) handleFedEvent(ctx context.Context, host *models.PDS, env *events.XRPCStreamEvent) error { + ctx, span := tracer.Start(ctx, "handleFedEvent") + defer span.End() + + start := time.Now() + defer func() { + eventsHandleDuration.WithLabelValues(host.Host).Observe(time.Since(start).Seconds()) + }() + + eventsReceivedCounter.WithLabelValues(host.Host).Add(1) + + switch { + case env.RepoCommit != nil: + repoCommitsReceivedCounter.WithLabelValues(host.Host).Add(1) + return bgs.handleCommit(ctx, host, env.RepoCommit) + case env.RepoSync != nil: + return bgs.handleSync(ctx, host, env.RepoSync) + case env.RepoHandle != nil: + // TODO: DEPRECATED - expect Identity message below instead + bgs.log.Info("bgs got repo handle event", "did", env.RepoHandle.Did, "handle", env.RepoHandle.Handle) + // Flush any cached DID documents for this user + bgs.purgeDidCache(ctx, env.RepoHandle.Did) + + // TODO: ignoring the data in the message and just going out to the DID doc + act, err := bgs.createExternalUser(ctx, env.RepoHandle.Did, host) + if err != nil { + return err + } + + if act.Handle.String != env.RepoHandle.Handle { + bgs.log.Warn("handle update did not update handle to asserted value", "did", env.RepoHandle.Did, "expected", env.RepoHandle.Handle, "actual", act.Handle) + } + + // TODO: Update the ReposHandle event type to include "verified" or something + + // Broadcast the handle update to all consumers + err = bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ + RepoHandle: &comatproto.SyncSubscribeRepos_Handle{ + Did: env.RepoHandle.Did, + Handle: env.RepoHandle.Handle, + Time: env.RepoHandle.Time, + }, + }) + if err != nil { + bgs.log.Error("failed to broadcast RepoHandle event", "error", err, "did", env.RepoHandle.Did, "handle", env.RepoHandle.Handle) + return fmt.Errorf("failed to broadcast RepoHandle event: %w", err) + } + + return nil + case env.RepoIdentity != nil: + bgs.log.Info("bgs got identity event", "did", env.RepoIdentity.Did) + // Flush any cached DID documents for this user + bgs.purgeDidCache(ctx, env.RepoIdentity.Did) + + // Refetch the DID doc and update our cached keys and handle etc. 
+ _, err := bgs.createExternalUser(ctx, env.RepoIdentity.Did, host) + if err != nil { + return err + } + + // Broadcast the identity event to all consumers + err = bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ + RepoIdentity: &comatproto.SyncSubscribeRepos_Identity{ + Did: env.RepoIdentity.Did, + Seq: env.RepoIdentity.Seq, + Time: env.RepoIdentity.Time, + Handle: env.RepoIdentity.Handle, + }, + }) + if err != nil { + bgs.log.Error("failed to broadcast Identity event", "error", err, "did", env.RepoIdentity.Did) + return fmt.Errorf("failed to broadcast Identity event: %w", err) + } + + return nil + case env.RepoAccount != nil: + span.SetAttributes( + attribute.String("did", env.RepoAccount.Did), + attribute.Int64("seq", env.RepoAccount.Seq), + attribute.Bool("active", env.RepoAccount.Active), + ) + + if env.RepoAccount.Status != nil { + span.SetAttributes(attribute.String("repo_status", *env.RepoAccount.Status)) + } + + bgs.log.Info("bgs got account event", "did", env.RepoAccount.Did) + // Flush any cached DID documents for this user + bgs.purgeDidCache(ctx, env.RepoAccount.Did) + + // Refetch the DID doc to make sure the PDS is still authoritative + ai, err := bgs.createExternalUser(ctx, env.RepoAccount.Did, host) + if err != nil { + span.RecordError(err) + return err + } + + // Check if the PDS is still authoritative + // if not we don't want to be propagating this account event + if ai.PDS != host.ID { + bgs.log.Error("account event from non-authoritative pds", + "seq", env.RepoAccount.Seq, + "did", env.RepoAccount.Did, + "event_from", host.Host, + "did_doc_declared_pds", ai.PDS, + "account_evt", env.RepoAccount, + ) + return fmt.Errorf("event from non-authoritative pds") + } + + // Process the account status change + repoStatus := events.AccountStatusActive + if !env.RepoAccount.Active && env.RepoAccount.Status != nil { + repoStatus = *env.RepoAccount.Status + } + + err = bgs.UpdateAccountStatus(ctx, env.RepoAccount.Did, repoStatus) + if err != nil { + span.RecordError(err) + return fmt.Errorf("failed to update account status: %w", err) + } + + shouldBeActive := env.RepoAccount.Active + status := env.RepoAccount.Status + u, err := bgs.lookupUserByDid(ctx, env.RepoAccount.Did) + if err != nil { + return fmt.Errorf("failed to look up user by did: %w", err) + } + + if u.GetTakenDown() { + shouldBeActive = false + status = &events.AccountStatusTakendown + } + + // Broadcast the account event to all consumers + err = bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ + RepoAccount: &comatproto.SyncSubscribeRepos_Account{ + Did: env.RepoAccount.Did, + Seq: env.RepoAccount.Seq, + Time: env.RepoAccount.Time, + Active: shouldBeActive, + Status: status, + }, + }) + if err != nil { + bgs.log.Error("failed to broadcast Account event", "error", err, "did", env.RepoAccount.Did) + return fmt.Errorf("failed to broadcast Account event: %w", err) + } + + return nil + case env.RepoMigrate != nil: + // TODO: DEPRECATED - expect account message above instead + if _, err := bgs.createExternalUser(ctx, env.RepoMigrate.Did, host); err != nil { + return err + } + + return nil + case env.RepoTombstone != nil: + // TODO: DEPRECATED - expect account message above instead + if err := bgs.handleRepoTombstone(ctx, host, env.RepoTombstone); err != nil { + return err + } + + return nil + default: + return fmt.Errorf("invalid fed event") + } +} + +func (bgs *BGS) handleCommit(ctx context.Context, host *models.PDS, evt *comatproto.SyncSubscribeRepos_Commit) error { + bgs.log.Debug("bgs got repo append event", "seq", 
evt.Seq, "pdsHost", host.Host, "repo", evt.Repo) + + //s := time.Now() + u, err := bgs.lookupUserByDid(ctx, evt.Repo) + //userLookupDuration.Observe(time.Since(s).Seconds()) + if err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + repoCommitsResultCounter.WithLabelValues(host.Host, "nou").Inc() + return fmt.Errorf("looking up event user: %w", err) + } + + newUsersDiscovered.Inc() + start := time.Now() + subj, err := bgs.createExternalUser(ctx, evt.Repo, host) + newUserDiscoveryDuration.Observe(time.Since(start).Seconds()) + if err != nil { + repoCommitsResultCounter.WithLabelValues(host.Host, "uerr").Inc() + return fmt.Errorf("fed event create external user: %w", err) + } + + u = subj + } + + ustatus := u.GetUpstreamStatus() + //span.SetAttributes(attribute.String("upstream_status", ustatus)) + + if u.GetTakenDown() || ustatus == events.AccountStatusTakendown { + //span.SetAttributes(attribute.Bool("taken_down_by_relay_admin", u.GetTakenDown())) + bgs.log.Debug("dropping commit event from taken down user", "did", evt.Repo, "seq", evt.Seq, "pdsHost", host.Host) + repoCommitsResultCounter.WithLabelValues(host.Host, "tdu").Inc() + return nil + } + + if ustatus == events.AccountStatusSuspended { + bgs.log.Debug("dropping commit event from suspended user", "did", evt.Repo, "seq", evt.Seq, "pdsHost", host.Host) + repoCommitsResultCounter.WithLabelValues(host.Host, "susu").Inc() + return nil + } + + if ustatus == events.AccountStatusDeactivated { + bgs.log.Debug("dropping commit event from deactivated user", "did", evt.Repo, "seq", evt.Seq, "pdsHost", host.Host) + repoCommitsResultCounter.WithLabelValues(host.Host, "du").Inc() + return nil + } + + if evt.Rebase { + repoCommitsResultCounter.WithLabelValues(host.Host, "rebase").Inc() + return fmt.Errorf("rebase was true in event seq:%d,host:%s", evt.Seq, host.Host) + } + + if host.ID != u.PDS && u.PDS != 0 { + bgs.log.Warn("received event for repo from different pds than expected", "repo", evt.Repo, "expPds", u.PDS, "gotPds", host.Host) + // Flush any cached DID documents for this user + bgs.purgeDidCache(ctx, evt.Repo) + + subj, err := bgs.createExternalUser(ctx, evt.Repo, host) + if err != nil { + repoCommitsResultCounter.WithLabelValues(host.Host, "uerr2").Inc() + return err + } + + if subj.PDS != host.ID { + repoCommitsResultCounter.WithLabelValues(host.Host, "noauth").Inc() + return fmt.Errorf("event from non-authoritative pds") + } + } + + if u.GetTombstoned() { + // TODO: reevaluate user lifecycle - tombstoned -- bolson 2025 + + //span.SetAttributes(attribute.Bool("tombstoned", true)) + // we've checked the authority of the users PDS, so reinstate the account + if err := bgs.db.Model(&User{}).Where("id = ?", u.ID).UpdateColumn("tombstoned", false).Error; err != nil { + repoCommitsResultCounter.WithLabelValues(host.Host, "tomb").Inc() + return fmt.Errorf("failed to un-tombstone a user: %w", err) + } + u.SetTombstoned(false) + + //ai, err := bgs.Index.LookupUser(ctx, u.ID) + //if err != nil { + // repoCommitsResultCounter.WithLabelValues(host.Host, "nou2").Inc() + // return fmt.Errorf("failed to look up user (tombstone recover): %w", err) + //} + + // Now a simple re-crawl should suffice to bring the user back online + //repoCommitsResultCounter.WithLabelValues(host.Host, "catchupt").Inc() + //return bgs.Index.Crawler.AddToCatchupQueue(ctx, host, ai, evt) + // TODO: fall through and just handle the event and the right thing should happen? 
-- bolson 2025 unsure + } + + var prevState UserPreviousState + err = bgs.db.First(&prevState, u.ID).Error + //prevP := &prevState + var prevP repomgr.UserPrev = &prevState + if errors.Is(err, gorm.ErrRecordNotFound) { + prevP = nil + } else if err != nil { + bgs.log.Error("failed to get previous root", "err", err) + prevP = nil + } + dbPrevRootStr := "" + dbPrevSeqStr := "" + if prevP != nil { + if prevState.Seq >= evt.Seq && ((prevState.Seq - evt.Seq) < 2000) { + // ignore catchup overlap of 200 on some subscribeRepos restarts + repoCommitsResultCounter.WithLabelValues(host.Host, "dup").Inc() + return nil + } + dbPrevRootStr = prevState.Cid.CID.String() + dbPrevSeqStr = strconv.FormatInt(prevState.Seq, 10) + } + evtPrevDataStr := "" + if evt.PrevData != nil { + evtPrevDataStr = ((*cid.Cid)(evt.PrevData)).String() + } + newRootCid, err := bgs.repoman.HandleCommit(ctx, host, u, evt, prevP) + if err != nil { + bgs.inductionTraceLog.Error("commit bad", "seq", evt.Seq, "pseq", dbPrevSeqStr, "pdsHost", host.Host, "repo", evt.Repo, "prev", evtPrevDataStr, "dbprev", dbPrevRootStr, "err", err) + bgs.log.Warn("failed handling event", "err", err, "pdsHost", host.Host, "seq", evt.Seq, "repo", u.Did, "commit", evt.Commit.String()) + repoCommitsResultCounter.WithLabelValues(host.Host, "err").Inc() + return fmt.Errorf("handle user event failed: %w", err) + } else { + // store now verified new repo state + prevState.Uid = u.ID + prevState.Cid.CID = *newRootCid + prevState.Rev = evt.Rev + prevState.Seq = evt.Seq + bgs.inductionTraceLog.Info("commit ok", "seq", evt.Seq, "pdsHost", host.Host, "repo", evt.Repo, "prev", evtPrevDataStr, "dbprev", dbPrevRootStr, "nextprev", newRootCid.String()) + if prevP == nil { + err = bgs.db.Create(&prevState).Error + if err != nil { + return fmt.Errorf("failed to create previous root uid=%d: %w", u.ID, err) + } + } else { + err = bgs.db.Save(&prevState).Error + if err != nil { + return fmt.Errorf("failed to save previous root uid=%d: %w", u.ID, err) + } + } + } + + repoCommitsResultCounter.WithLabelValues(host.Host, "ok").Inc() + return nil +} + +func (bgs *BGS) handleSync(ctx context.Context, host *models.PDS, evt *comatproto.SyncSubscribeRepos_Sync) error { + // TODO: actually do something with #sync event + + // Broadcast the identity event to all consumers + err := bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ + RepoSync: evt, + }) + if err != nil { + bgs.log.Error("failed to broadcast sync event", "error", err, "did", evt.Did) + return fmt.Errorf("failed to broadcast sync event: %w", err) + } + + return nil +} + +func (bgs *BGS) handleRepoTombstone(ctx context.Context, pds *models.PDS, evt *comatproto.SyncSubscribeRepos_Tombstone) error { + u, err := bgs.lookupUserByDid(ctx, evt.Did) + if err != nil { + return err + } + + if u.PDS != pds.ID { + return fmt.Errorf("unauthoritative tombstone event from %s for %s", pds.Host, evt.Did) + } + + if err := bgs.db.Model(&User{}).Where("id = ?", u.ID).UpdateColumns(map[string]any{ + "tombstoned": true, + "handle": nil, + }).Error; err != nil { + return err + } + u.SetTombstoned(true) + + return bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ + RepoTombstone: evt, + }) +} + +func (bgs *BGS) purgeDidCache(ctx context.Context, did string) { + ati, err := syntax.ParseAtIdentifier(did) + if err != nil { + return + } + _ = bgs.didd.Purge(ctx, *ati) +} + +// createExternalUser is a mess and poorly defined +// did is the user +// host is the PDS we received this from, not necessarily the canonical PDS in the DID document +// 
TODO: rename? This also updates users, and 'external' is an old phrasing +func (bgs *BGS) createExternalUser(ctx context.Context, did string, host *models.PDS) (*User, error) { + ctx, span := tracer.Start(ctx, "createExternalUser") + defer span.End() + + externalUserCreationAttempts.Inc() + + bgs.log.Debug("create external user", "did", did) + pdid, err := syntax.ParseDID(did) + if err != nil { + return nil, fmt.Errorf("bad did %#v, %w", did, err) + } + ident, err := bgs.didd.LookupDID(ctx, pdid) + if err != nil { + return nil, fmt.Errorf("no ident for did %s, %w", did, err) + } + if len(ident.Services) == 0 { + return nil, fmt.Errorf("no services for did %s", did) + } + pdsService, ok := ident.Services["atproto_pds"] + if !ok { + return nil, fmt.Errorf("no atproto_pds service for did %s", did) + } + durl, err := url.Parse(pdsService.URL) + if err != nil { + return nil, fmt.Errorf("pds bad url %#v, %w", pdsService.URL, err) + } + + if strings.HasPrefix(durl.Host, "localhost:") { + durl.Scheme = "http" + } + + // TODO: the PDS's DID should also be in the service, we could use that to look up? + var peering models.PDS + if err := bgs.db.Find(&peering, "host = ?", durl.Host).Error; err != nil { + bgs.log.Error("failed to find pds", "host", durl.Host) + return nil, err + } + + ban, err := bgs.domainIsBanned(ctx, durl.Host) + if err != nil { + return nil, fmt.Errorf("failed to check pds ban status: %w", err) + } + + if ban { + return nil, fmt.Errorf("cannot create user on pds with banned domain") + } + + if peering.ID == 0 { + // why didn't we know about this PDS from requestCrawl? + // _maybe this never happens_ ? because of above peering Find ? + bgs.log.Warn("pds discovered in new user flow", "pds", durl.String(), "did", did) + pclient := &xrpc.Client{Host: durl.String()} + bgs.config.ApplyPDSClientSettings(pclient) + // TODO: the case of handling a new user on a new PDS probably requires more thought + cfg, err := comatproto.ServerDescribeServer(ctx, pclient) + if err != nil { + // TODO: failing this shouldn't halt our indexing + return nil, fmt.Errorf("failed to check unrecognized pds: %w", err) + } + + // since handles can be anything, checking against this list doesn't matter... 
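// Illustrative sketch (editor's addition, not part of this patch): the describeServer probe
// used to sanity-check an unknown PDS host before creating a pds row. Any well-formed
// response is treated as good enough; the advertised user domains are only informational.
// The host below is made up, and the real client also gets ApplyPDSClientSettings applied.
package main

import (
	"context"
	"fmt"

	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/xrpc"
)

func main() {
	ctx := context.Background()
	client := &xrpc.Client{Host: "https://pds.example.com"}

	desc, err := comatproto.ServerDescribeServer(ctx, client)
	if err != nil {
		fmt.Println("host did not answer like a PDS:", err)
		return
	}
	fmt.Println("looks like a PDS, user domains:", desc.AvailableUserDomains)
}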
+ _ = cfg + + // TODO: could check other things, a valid response is good enough for now + peering.Host = durl.Host + peering.SSL = (durl.Scheme == "https") + peering.RateLimit = float64(bgs.slurper.DefaultPerSecondLimit) + peering.HourlyEventLimit = bgs.slurper.DefaultPerHourLimit + peering.DailyEventLimit = bgs.slurper.DefaultPerDayLimit + peering.RepoLimit = bgs.slurper.DefaultRepoLimit + + if bgs.ssl && !peering.SSL { + return nil, fmt.Errorf("did references non-ssl PDS, this is disallowed in prod: %q %q", did, pdsService.URL) + } + + if err := bgs.db.Create(&peering).Error; err != nil { + return nil, err + } + } + + if peering.ID == 0 { + panic("somehow failed to create a pds entry?") + } + + if peering.Blocked { + return nil, fmt.Errorf("refusing to create user with blocked PDS") + } + + if peering.RepoCount >= peering.RepoLimit { + return nil, fmt.Errorf("refusing to create user on PDS at max repo limit for pds %q", peering.Host) + } + + bgs.extUserLk.Lock() + defer bgs.extUserLk.Unlock() + + user, err := bgs.lookupUserByDid(ctx, did) + if err == nil { + bgs.log.Debug("lost the race to create a new user", "did", did, "existing_hand", user.Handle) + if user.PDS != peering.ID { + // User is now on a different PDS, update + err = bgs.db.Transaction(func(tx *gorm.DB) error { + res := tx.Model(User{}).Where("id = ?", user.ID).Update("pds", peering.ID) + if res.Error != nil { + return fmt.Errorf("failed to update users pds: %w", res.Error) + } + res = tx.Model(&models.PDS{}).Where("id = ? AND repo_count < repo_limit", peering.ID).Update("repo_count", gorm.Expr("repo_count + 1")) + return nil + }) + + user.PDS = peering.ID + } + return user, nil + } + + if !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, err + } + + // TODO: request this users info from their server to fill out our data... + u := User{ + Did: did, + PDS: peering.ID, + ValidHandle: false, + } + + err = bgs.db.Transaction(func(tx *gorm.DB) error { + res := tx.Model(&models.PDS{}).Where("id = ? 
AND repo_count < repo_limit", peering.ID).Update("repo_count", gorm.Expr("repo_count + 1")) + if res.Error != nil { + return fmt.Errorf("failed to increment repo count for pds %q: %w", peering.Host, res.Error) + } + if terr := bgs.db.Create(&u).Error; terr != nil { + bgs.log.Error("failed to create user", "did", u.Did, "err", terr) + return fmt.Errorf("failed to create other pds user: %w", terr) + } + return nil + }) + if err != nil { + bgs.log.Error("failed to create user and increment pds repo count", "err", err) + return nil, err + } + + return &u, nil +} + +func (bgs *BGS) UpdateAccountStatus(ctx context.Context, did string, status string) error { + ctx, span := tracer.Start(ctx, "UpdateAccountStatus") + defer span.End() + + span.SetAttributes( + attribute.String("did", did), + attribute.String("status", status), + ) + + u, err := bgs.lookupUserByDid(ctx, did) + if err != nil { + return err + } + + switch status { + case events.AccountStatusActive: + // Unset the PDS-specific status flags + if err := bgs.db.Model(User{}).Where("id = ?", u.ID).Update("upstream_status", events.AccountStatusActive).Error; err != nil { + return fmt.Errorf("failed to set user active status: %w", err) + } + u.SetUpstreamStatus(events.AccountStatusActive) + case events.AccountStatusDeactivated: + if err := bgs.db.Model(User{}).Where("id = ?", u.ID).Update("upstream_status", events.AccountStatusDeactivated).Error; err != nil { + return fmt.Errorf("failed to set user deactivation status: %w", err) + } + u.SetUpstreamStatus(events.AccountStatusDeactivated) + case events.AccountStatusSuspended: + if err := bgs.db.Model(User{}).Where("id = ?", u.ID).Update("upstream_status", events.AccountStatusSuspended).Error; err != nil { + return fmt.Errorf("failed to set user suspension status: %w", err) + } + u.SetUpstreamStatus(events.AccountStatusSuspended) + case events.AccountStatusTakendown: + if err := bgs.db.Model(User{}).Where("id = ?", u.ID).Update("upstream_status", events.AccountStatusTakendown).Error; err != nil { + return fmt.Errorf("failed to set user taken down status: %w", err) + } + u.SetUpstreamStatus(events.AccountStatusTakendown) + // TODO: set User takedown in db? -- bolson 2025 + case events.AccountStatusDeleted: + // TODO: tweak model to mark user deleted?
-- bolson 2025 + if err := bgs.db.Model(&User{}).Where("id = ?", u.ID).UpdateColumns(map[string]any{ + "tombstoned": true, + "handle": nil, + "upstream_status": events.AccountStatusDeleted, + }).Error; err != nil { + return err + } + u.SetUpstreamStatus(events.AccountStatusDeleted) + } + + return nil +} + +func (bgs *BGS) TakeDownRepo(ctx context.Context, did string) error { + u, err := bgs.lookupUserByDid(ctx, did) + if err != nil { + return err + } + + if err := bgs.db.Model(User{}).Where("id = ?", u.ID).Update("taken_down", true).Error; err != nil { + return err + } + u.SetTakenDown(true) + + if err := bgs.events.TakeDownRepo(ctx, u.ID); err != nil { + return err + } + + return nil +} + +func (bgs *BGS) ReverseTakedown(ctx context.Context, did string) error { + u, err := bgs.lookupUserByDid(ctx, did) + if err != nil { + return err + } + + if err := bgs.db.Model(User{}).Where("id = ?", u.ID).Update("taken_down", false).Error; err != nil { + return err + } + u.SetTakenDown(false) + + return nil +} + +func (bgs *BGS) GetRepoRoot(ctx context.Context, user models.Uid) (cid.Cid, error) { + var prevState UserPreviousState + err := bgs.db.First(&prevState, user).Error + if err == nil { + return prevState.Cid.CID, nil + } else if errors.Is(err, gorm.ErrRecordNotFound) { + return cid.Cid{}, ErrUserStatusUnavailable + } else { + bgs.log.Error("user db err", "err", err) + return cid.Cid{}, fmt.Errorf("user prev db err, %w", err) + } +} diff --git a/cmd/relay/bgs/fedmgr.go b/cmd/relay/bgs/fedmgr.go new file mode 100644 index 000000000..94c5cdb42 --- /dev/null +++ b/cmd/relay/bgs/fedmgr.go @@ -0,0 +1,778 @@ +package bgs + +import ( + "context" + "errors" + "fmt" + "log/slog" + "math/rand" + "strings" + "sync" + "time" + + "github.com/RussellLuo/slidingwindow" + comatproto "github.com/bluesky-social/indigo/api/atproto" + "github.com/bluesky-social/indigo/cmd/relay/events" + "github.com/bluesky-social/indigo/cmd/relay/events/schedulers/parallel" + "github.com/bluesky-social/indigo/cmd/relay/models" + + "github.com/gorilla/websocket" + pq "github.com/lib/pq" + "gorm.io/gorm" +) + +type IndexCallback func(context.Context, *models.PDS, *events.XRPCStreamEvent) error + +type Slurper struct { + cb IndexCallback + + db *gorm.DB + + lk sync.Mutex + active map[string]*activeSub + + LimitMux sync.RWMutex + Limiters map[uint]*Limiters + DefaultPerSecondLimit int64 + DefaultPerHourLimit int64 + DefaultPerDayLimit int64 + + DefaultRepoLimit int64 + ConcurrencyPerPDS int64 + MaxQueuePerPDS int64 + + NewPDSPerDayLimiter *slidingwindow.Limiter + + newSubsDisabled bool + trustedDomains []string + + shutdownChan chan bool + shutdownResult chan []error + + ssl bool + + log *slog.Logger +} + +type Limiters struct { + PerSecond *slidingwindow.Limiter + PerHour *slidingwindow.Limiter + PerDay *slidingwindow.Limiter +} + +type SlurperOptions struct { + SSL bool + DefaultPerSecondLimit int64 + DefaultPerHourLimit int64 + DefaultPerDayLimit int64 + DefaultRepoLimit int64 + ConcurrencyPerPDS int64 + MaxQueuePerPDS int64 + + Logger *slog.Logger +} + +func DefaultSlurperOptions() *SlurperOptions { + return &SlurperOptions{ + SSL: false, + DefaultPerSecondLimit: 50, + DefaultPerHourLimit: 2500, + DefaultPerDayLimit: 20_000, + DefaultRepoLimit: 100, + ConcurrencyPerPDS: 100, + MaxQueuePerPDS: 1_000, + + Logger: slog.Default(), + } +} + +type activeSub struct { + pds *models.PDS + lk sync.RWMutex + ctx context.Context + cancel func() +} + +func (sub *activeSub) updateCursor(curs int64) { + sub.lk.Lock() + defer 
sub.lk.Unlock() + sub.pds.Cursor = curs +} + +func NewSlurper(db *gorm.DB, cb IndexCallback, opts *SlurperOptions) (*Slurper, error) { + if opts == nil { + opts = DefaultSlurperOptions() + } + err := db.AutoMigrate(&SlurpConfig{}) + if err != nil { + return nil, err + } + s := &Slurper{ + cb: cb, + db: db, + active: make(map[string]*activeSub), + Limiters: make(map[uint]*Limiters), + DefaultPerSecondLimit: opts.DefaultPerSecondLimit, + DefaultPerHourLimit: opts.DefaultPerHourLimit, + DefaultPerDayLimit: opts.DefaultPerDayLimit, + DefaultRepoLimit: opts.DefaultRepoLimit, + ConcurrencyPerPDS: opts.ConcurrencyPerPDS, + MaxQueuePerPDS: opts.MaxQueuePerPDS, + ssl: opts.SSL, + shutdownChan: make(chan bool), + shutdownResult: make(chan []error), + log: opts.Logger, + } + if err := s.loadConfig(); err != nil { + return nil, err + } + + // Start a goroutine to flush cursors to the DB every 30s + go func() { + for { + select { + case <-s.shutdownChan: + s.log.Info("flushing PDS cursors on shutdown") + ctx := context.Background() + //ctx, span := otel.Tracer("feedmgr").Start(ctx, "CursorFlusherShutdown") + //defer span.End() + var errs []error + if errs = s.flushCursors(ctx); len(errs) > 0 { + for _, err := range errs { + s.log.Error("failed to flush cursors on shutdown", "err", err) + } + } + s.log.Info("done flushing PDS cursors on shutdown") + s.shutdownResult <- errs + return + case <-time.After(time.Second * 10): + s.log.Debug("flushing PDS cursors") + ctx := context.Background() + //ctx, span := otel.Tracer("feedmgr").Start(ctx, "CursorFlusher") + //defer span.End() + if errs := s.flushCursors(ctx); len(errs) > 0 { + for _, err := range errs { + s.log.Error("failed to flush cursors", "err", err) + } + } + s.log.Debug("done flushing PDS cursors") + } + } + }() + + return s, nil +} + +func windowFunc() (slidingwindow.Window, slidingwindow.StopFunc) { + return slidingwindow.NewLocalWindow() +} + +func (s *Slurper) GetLimiters(pdsID uint) *Limiters { + s.LimitMux.RLock() + defer s.LimitMux.RUnlock() + return s.Limiters[pdsID] +} + +func (s *Slurper) GetOrCreateLimiters(pdsID uint, perSecLimit int64, perHourLimit int64, perDayLimit int64) *Limiters { + s.LimitMux.RLock() + defer s.LimitMux.RUnlock() + lim, ok := s.Limiters[pdsID] + if !ok { + perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc) + perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc) + perDay, _ := slidingwindow.NewLimiter(time.Hour*24, perDayLimit, windowFunc) + lim = &Limiters{ + PerSecond: perSec, + PerHour: perHour, + PerDay: perDay, + } + s.Limiters[pdsID] = lim + } + + return lim +} + +func (s *Slurper) SetLimits(pdsID uint, perSecLimit int64, perHourLimit int64, perDayLimit int64) { + s.LimitMux.Lock() + defer s.LimitMux.Unlock() + lim, ok := s.Limiters[pdsID] + if !ok { + perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc) + perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc) + perDay, _ := slidingwindow.NewLimiter(time.Hour*24, perDayLimit, windowFunc) + lim = &Limiters{ + PerSecond: perSec, + PerHour: perHour, + PerDay: perDay, + } + s.Limiters[pdsID] = lim + } + + lim.PerSecond.SetLimit(perSecLimit) + lim.PerHour.SetLimit(perHourLimit) + lim.PerDay.SetLimit(perDayLimit) +} + +// Shutdown shuts down the slurper +func (s *Slurper) Shutdown() []error { + s.shutdownChan <- true + s.log.Info("waiting for slurper shutdown") + errs := <-s.shutdownResult + if len(errs) > 0 { + for _, err := range errs { + s.log.Error("shutdown error", 
"err", err) + } + } + s.log.Info("slurper shutdown complete") + return errs +} + +func (s *Slurper) loadConfig() error { + var sc SlurpConfig + if err := s.db.Find(&sc).Error; err != nil { + return err + } + + if sc.ID == 0 { + if err := s.db.Create(&SlurpConfig{}).Error; err != nil { + return err + } + } + + s.newSubsDisabled = sc.NewSubsDisabled + s.trustedDomains = sc.TrustedDomains + + s.NewPDSPerDayLimiter, _ = slidingwindow.NewLimiter(time.Hour*24, sc.NewPDSPerDayLimit, windowFunc) + + return nil +} + +type SlurpConfig struct { + gorm.Model + + NewSubsDisabled bool + TrustedDomains pq.StringArray `gorm:"type:text[]"` + NewPDSPerDayLimit int64 +} + +func (s *Slurper) SetNewSubsDisabled(dis bool) error { + s.lk.Lock() + defer s.lk.Unlock() + + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("new_subs_disabled", dis).Error; err != nil { + return err + } + + s.newSubsDisabled = dis + return nil +} + +func (s *Slurper) GetNewSubsDisabledState() bool { + s.lk.Lock() + defer s.lk.Unlock() + return s.newSubsDisabled +} + +func (s *Slurper) SetNewPDSPerDayLimit(limit int64) error { + s.lk.Lock() + defer s.lk.Unlock() + + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("new_pds_per_day_limit", limit).Error; err != nil { + return err + } + + s.NewPDSPerDayLimiter.SetLimit(limit) + return nil +} + +func (s *Slurper) GetNewPDSPerDayLimit() int64 { + s.lk.Lock() + defer s.lk.Unlock() + return s.NewPDSPerDayLimiter.Limit() +} + +func (s *Slurper) AddTrustedDomain(domain string) error { + s.lk.Lock() + defer s.lk.Unlock() + + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", gorm.Expr("array_append(trusted_domains, ?)", domain)).Error; err != nil { + return err + } + + s.trustedDomains = append(s.trustedDomains, domain) + return nil +} + +func (s *Slurper) RemoveTrustedDomain(domain string) error { + s.lk.Lock() + defer s.lk.Unlock() + + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", gorm.Expr("array_remove(trusted_domains, ?)", domain)).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil + } + return err + } + + for i, d := range s.trustedDomains { + if d == domain { + s.trustedDomains = append(s.trustedDomains[:i], s.trustedDomains[i+1:]...) 
+ break + } + } + + return nil +} + +func (s *Slurper) SetTrustedDomains(domains []string) error { + s.lk.Lock() + defer s.lk.Unlock() + + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", domains).Error; err != nil { + return err + } + + s.trustedDomains = domains + return nil +} + +func (s *Slurper) GetTrustedDomains() []string { + s.lk.Lock() + defer s.lk.Unlock() + return s.trustedDomains +} + +var ErrNewSubsDisabled = fmt.Errorf("new subscriptions temporarily disabled") + +// Checks whether a host is allowed to be subscribed to +// must be called with the slurper lock held +func (s *Slurper) canSlurpHost(host string) bool { + // Check if we're over the limit for new PDSs today + if !s.NewPDSPerDayLimiter.Allow() { + return false + } + + // Check if the host is a trusted domain + for _, d := range s.trustedDomains { + // If the domain starts with a *., it's a wildcard + if strings.HasPrefix(d, "*.") { + // Cut off the * so we have .domain.com + if strings.HasSuffix(host, strings.TrimPrefix(d, "*")) { + return true + } + } else { + if host == d { + return true + } + } + } + + return !s.newSubsDisabled +} + +func (s *Slurper) SubscribeToPds(ctx context.Context, host string, reg bool, adminOverride bool, rateOverrides *PDSRates) error { + // TODO: for performance, lock on the hostname instead of global + s.lk.Lock() + defer s.lk.Unlock() + + _, ok := s.active[host] + if ok { + return nil + } + + var peering models.PDS + if err := s.db.Find(&peering, "host = ?", host).Error; err != nil { + return err + } + + if peering.Blocked { + return fmt.Errorf("cannot subscribe to blocked pds") + } + + newHost := false + + if peering.ID == 0 { + if !adminOverride && !s.canSlurpHost(host) { + return ErrNewSubsDisabled + } + // New PDS! 
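// Illustrative sketch (editor's addition, not part of this patch): the wildcard matching
// canSlurpHost applies to trusted domains. An entry like "*.host.example.com" trusts any
// subdomain (the "*" is trimmed and a suffix match on ".host.example.com" is used), while
// an entry without "*." must match the host exactly. Domain names are made up.
package main

import (
	"fmt"
	"strings"
)

func matchesTrusted(host, trusted string) bool {
	if strings.HasPrefix(trusted, "*.") {
		return strings.HasSuffix(host, strings.TrimPrefix(trusted, "*"))
	}
	return host == trusted
}

func main() {
	fmt.Println(matchesTrusted("pds01.host.example.com", "*.host.example.com")) // true
	fmt.Println(matchesTrusted("host.example.com", "*.host.example.com"))       // false: wildcard needs a leading label
	fmt.Println(matchesTrusted("host.example.com", "host.example.com"))         // true: exact entry
}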
+ npds := models.PDS{ + Host: host, + SSL: s.ssl, + Registered: reg, + RateLimit: float64(s.DefaultPerSecondLimit), + HourlyEventLimit: s.DefaultPerHourLimit, + DailyEventLimit: s.DefaultPerDayLimit, + RepoLimit: s.DefaultRepoLimit, + } + if rateOverrides != nil { + npds.RateLimit = float64(rateOverrides.PerSecond) + npds.HourlyEventLimit = rateOverrides.PerHour + npds.DailyEventLimit = rateOverrides.PerDay + npds.RepoLimit = rateOverrides.RepoLimit + } + if err := s.db.Create(&npds).Error; err != nil { + return err + } + + newHost = true + peering = npds + } + + if !peering.Registered && reg { + peering.Registered = true + if err := s.db.Model(models.PDS{}).Where("id = ?", peering.ID).Update("registered", true).Error; err != nil { + return err + } + } + + ctx, cancel := context.WithCancel(context.Background()) + sub := activeSub{ + pds: &peering, + ctx: ctx, + cancel: cancel, + } + s.active[host] = &sub + + s.GetOrCreateLimiters(peering.ID, int64(peering.RateLimit), peering.HourlyEventLimit, peering.DailyEventLimit) + + go s.subscribeWithRedialer(ctx, &peering, &sub, newHost) + + return nil +} + +func (s *Slurper) RestartAll() error { + s.lk.Lock() + defer s.lk.Unlock() + + var all []models.PDS + if err := s.db.Find(&all, "registered = true AND blocked = false").Error; err != nil { + return err + } + + for _, pds := range all { + pds := pds + + ctx, cancel := context.WithCancel(context.Background()) + sub := activeSub{ + pds: &pds, + ctx: ctx, + cancel: cancel, + } + s.active[pds.Host] = &sub + + // Check if we've already got a limiter for this PDS + s.GetOrCreateLimiters(pds.ID, int64(pds.RateLimit), pds.HourlyEventLimit, pds.DailyEventLimit) + go s.subscribeWithRedialer(ctx, &pds, &sub, false) + } + + return nil +} + +func (s *Slurper) subscribeWithRedialer(ctx context.Context, host *models.PDS, sub *activeSub, newHost bool) { + defer func() { + s.lk.Lock() + defer s.lk.Unlock() + + delete(s.active, host.Host) + }() + + d := websocket.Dialer{ + HandshakeTimeout: time.Second * 5, + } + + protocol := "ws" + if s.ssl { + protocol = "wss" + } + + // Special case `.host.bsky.network` PDSs to rewind cursor by 200 events to smooth over unclean shutdowns + if strings.HasSuffix(host.Host, ".host.bsky.network") && host.Cursor > 200 { + host.Cursor -= 200 + } + + cursor := host.Cursor + + connectedInbound.Inc() + defer connectedInbound.Dec() + // TODO:? maybe keep a gauge of 'in retry backoff' sources? 
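// Illustrative sketch (editor's addition, not part of this patch): the subscription URL
// subscribeWithRedialer dials. It uses wss:// (or ws:// when the relay runs without SSL)
// against /xrpc/com.atproto.sync.subscribeRepos, and only sends a cursor parameter for
// hosts we have consumed from before. Host and cursor values are made up.
package main

import (
	"fmt"
	"net/url"
	"strconv"
)

func subscribeURL(ssl bool, host string, cursor int64, newHost bool) string {
	scheme := "ws"
	if ssl {
		scheme = "wss"
	}
	u := url.URL{Scheme: scheme, Host: host, Path: "/xrpc/com.atproto.sync.subscribeRepos"}
	if !newHost {
		q := u.Query()
		q.Set("cursor", strconv.FormatInt(cursor, 10))
		u.RawQuery = q.Encode()
	}
	return u.String()
}

func main() {
	fmt.Println(subscribeURL(true, "pds.example.com", 12345, false))
	// wss://pds.example.com/xrpc/com.atproto.sync.subscribeRepos?cursor=12345
	fmt.Println(subscribeURL(true, "pds.example.com", 0, true))
	// wss://pds.example.com/xrpc/com.atproto.sync.subscribeRepos
}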
+ + var backoff int + for { + select { + case <-ctx.Done(): + return + default: + } + + var url string + if newHost { + url = fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos", protocol, host.Host) + } else { + url = fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos?cursor=%d", protocol, host.Host, cursor) + } + con, res, err := d.DialContext(ctx, url, nil) + if err != nil { + s.log.Warn("dialing failed", "pdsHost", host.Host, "err", err, "backoff", backoff) + time.Sleep(sleepForBackoff(backoff)) + backoff++ + + if backoff > 15 { + s.log.Warn("pds does not appear to be online, disabling for now", "pdsHost", host.Host) + if err := s.db.Model(&models.PDS{}).Where("id = ?", host.ID).Update("registered", false).Error; err != nil { + s.log.Error("failed to unregister failing pds", "err", err) + } + + return + } + + continue + } + + s.log.Info("event subscription response", "code", res.StatusCode, "url", url) + + curCursor := cursor + if err := s.handleConnection(ctx, host, con, &cursor, sub); err != nil { + if errors.Is(err, ErrTimeoutShutdown) { + s.log.Info("shutting down pds subscription after timeout", "host", host.Host, "time", EventsTimeout) + return + } + s.log.Warn("connection to failed", "host", host.Host, "err", err) + // TODO: measure the last N connection error times and if they're coming too fast reconnect slower or don't reconnect and wait for requestCrawl + } + + if cursor > curCursor { + backoff = 0 + } + } +} + +func sleepForBackoff(b int) time.Duration { + if b == 0 { + return 0 + } + + if b < 10 { + return (time.Duration(b) * 2) + (time.Millisecond * time.Duration(rand.Intn(1000))) + } + + return time.Second * 30 +} + +var ErrTimeoutShutdown = fmt.Errorf("timed out waiting for new events") + +var EventsTimeout = time.Minute + +func (s *Slurper) handleConnection(ctx context.Context, host *models.PDS, con *websocket.Conn, lastCursor *int64, sub *activeSub) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + rsc := &events.RepoStreamCallbacks{ + RepoCommit: func(evt *comatproto.SyncSubscribeRepos_Commit) error { + s.log.Debug("got remote repo event", "pdsHost", host.Host, "repo", evt.Repo, "seq", evt.Seq) + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ + RepoCommit: evt, + }); err != nil { + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) + } + *lastCursor = evt.Seq + + sub.updateCursor(*lastCursor) + + return nil + }, + RepoSync: func(evt *comatproto.SyncSubscribeRepos_Sync) error { + s.log.Debug("got remote repo event", "pdsHost", host.Host, "repo", evt.Did, "seq", evt.Seq) + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ + RepoSync: evt, + }); err != nil { + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) + } + *lastCursor = evt.Seq + + sub.updateCursor(*lastCursor) + + return nil + }, + RepoHandle: func(evt *comatproto.SyncSubscribeRepos_Handle) error { + s.log.Debug("got remote handle update event", "pdsHost", host.Host, "did", evt.Did, "handle", evt.Handle) + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ + RepoHandle: evt, + }); err != nil { + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) + } + *lastCursor = evt.Seq + + sub.updateCursor(*lastCursor) + + return nil + }, + RepoMigrate: func(evt *comatproto.SyncSubscribeRepos_Migrate) error { + s.log.Debug("got remote repo migrate event", "pdsHost", host.Host, "did", evt.Did, "migrateTo", evt.MigrateTo) + if err := s.cb(context.TODO(), 
host, &events.XRPCStreamEvent{ + RepoMigrate: evt, + }); err != nil { + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) + } + *lastCursor = evt.Seq + + sub.updateCursor(*lastCursor) + + return nil + }, + RepoTombstone: func(evt *comatproto.SyncSubscribeRepos_Tombstone) error { + s.log.Debug("got remote repo tombstone event", "pdsHost", host.Host, "did", evt.Did) + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ + RepoTombstone: evt, + }); err != nil { + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) + } + *lastCursor = evt.Seq + + sub.updateCursor(*lastCursor) + + return nil + }, + RepoInfo: func(info *comatproto.SyncSubscribeRepos_Info) error { + s.log.Debug("info event", "name", info.Name, "message", info.Message, "pdsHost", host.Host) + return nil + }, + RepoIdentity: func(ident *comatproto.SyncSubscribeRepos_Identity) error { + s.log.Debug("identity event", "did", ident.Did) + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ + RepoIdentity: ident, + }); err != nil { + s.log.Error("failed handling event", "host", host.Host, "seq", ident.Seq, "err", err) + } + *lastCursor = ident.Seq + + sub.updateCursor(*lastCursor) + + return nil + }, + RepoAccount: func(acct *comatproto.SyncSubscribeRepos_Account) error { + s.log.Debug("account event", "did", acct.Did, "status", acct.Status) + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ + RepoAccount: acct, + }); err != nil { + s.log.Error("failed handling event", "host", host.Host, "seq", acct.Seq, "err", err) + } + *lastCursor = acct.Seq + + sub.updateCursor(*lastCursor) + + return nil + }, + // TODO: all the other event types (handle change, migration, etc) + Error: func(errf *events.ErrorFrame) error { + switch errf.Error { + case "FutureCursor": + // if we get a FutureCursor frame, reset our sequence number for this host + if err := s.db.Table("pds").Where("id = ?", host.ID).Update("cursor", 0).Error; err != nil { + return err + } + + *lastCursor = 0 + return fmt.Errorf("got FutureCursor frame, reset cursor tracking for host") + default: + return fmt.Errorf("error frame: %s: %s", errf.Error, errf.Message) + } + }, + } + + lims := s.GetOrCreateLimiters(host.ID, int64(host.RateLimit), host.HourlyEventLimit, host.DailyEventLimit) + + limiters := []*slidingwindow.Limiter{ + lims.PerSecond, + lims.PerHour, + lims.PerDay, + } + + instrumentedRSC := events.NewInstrumentedRepoStreamCallbacks(limiters, rsc.EventHandler) + + pool := parallel.NewScheduler( + 100, + 1_000, + con.RemoteAddr().String(), + instrumentedRSC.EventHandler, + ) + return events.HandleRepoStream(ctx, con, pool, nil) +} + +type cursorSnapshot struct { + id uint + cursor int64 +} + +// flushCursors updates the PDS cursors in the DB for all active subscriptions +func (s *Slurper) flushCursors(ctx context.Context) []error { + start := time.Now() + //ctx, span := otel.Tracer("feedmgr").Start(ctx, "flushCursors") + //defer span.End() + + var cursors []cursorSnapshot + + s.lk.Lock() + // Iterate over active subs and copy the current cursor + for _, sub := range s.active { + sub.lk.RLock() + cursors = append(cursors, cursorSnapshot{ + id: sub.pds.ID, + cursor: sub.pds.Cursor, + }) + sub.lk.RUnlock() + } + s.lk.Unlock() + + errs := []error{} + okcount := 0 + + tx := s.db.WithContext(ctx).Begin() + for _, cursor := range cursors { + if err := tx.WithContext(ctx).Model(models.PDS{}).Where("id = ?", cursor.id).UpdateColumn("cursor", cursor.cursor).Error; err != nil { + errs = 
append(errs, err) + } else { + okcount++ + } + } + if err := tx.WithContext(ctx).Commit().Error; err != nil { + errs = append(errs, err) + } + dt := time.Since(start) + s.log.Info("flushCursors", "dt", dt, "ok", okcount, "errs", len(errs)) + + return errs +} + +func (s *Slurper) GetActiveList() []string { + s.lk.Lock() + defer s.lk.Unlock() + var out []string + for k := range s.active { + out = append(out, k) + } + + return out +} + +var ErrNoActiveConnection = fmt.Errorf("no active connection to host") + +func (s *Slurper) KillUpstreamConnection(host string, block bool) error { + s.lk.Lock() + defer s.lk.Unlock() + + ac, ok := s.active[host] + if !ok { + return fmt.Errorf("killing connection %q: %w", host, ErrNoActiveConnection) + } + ac.cancel() + // cleanup in the run thread subscribeWithRedialer() will delete(s.active, host) + + if block { + if err := s.db.Model(models.PDS{}).Where("id = ?", ac.pds.ID).UpdateColumn("blocked", true).Error; err != nil { + return fmt.Errorf("failed to set host as blocked: %w", err) + } + } + + return nil +} diff --git a/cmd/relay/bgs/handlers.go b/cmd/relay/bgs/handlers.go new file mode 100644 index 000000000..308c142a8 --- /dev/null +++ b/cmd/relay/bgs/handlers.go @@ -0,0 +1,207 @@ +package bgs + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + + atproto "github.com/bluesky-social/indigo/api/atproto" + comatprototypes "github.com/bluesky-social/indigo/api/atproto" + "github.com/bluesky-social/indigo/cmd/relay/events" + "gorm.io/gorm" + + "github.com/bluesky-social/indigo/xrpc" + "github.com/labstack/echo/v4" +) + +func (s *BGS) handleComAtprotoSyncRequestCrawl(ctx context.Context, body *comatprototypes.SyncRequestCrawl_Input) error { + host := body.Hostname + if host == "" { + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname") + } + + if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") { + if s.ssl { + host = "https://" + host + } else { + host = "http://" + host + } + } + + u, err := url.Parse(host) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to parse hostname") + } + + if u.Scheme == "http" && s.ssl { + return echo.NewHTTPError(http.StatusBadRequest, "this server requires https") + } + + if u.Scheme == "https" && !s.ssl { + return echo.NewHTTPError(http.StatusBadRequest, "this server does not support https") + } + + if u.Path != "" { + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without path") + } + + if u.Query().Encode() != "" { + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without query") + } + + host = u.Host // potentially hostname:port + + banned, err := s.domainIsBanned(ctx, host) + if banned { + return echo.NewHTTPError(http.StatusUnauthorized, "domain is banned") + } + + s.log.Warn("TODO: better host validation for crawl requests") + + clientHost := fmt.Sprintf("%s://%s", u.Scheme, host) + + c := &xrpc.Client{ + Host: clientHost, + Client: http.DefaultClient, // not using the client that auto-retries + } + + desc, err := atproto.ServerDescribeServer(ctx, c) + if err != nil { + errMsg := fmt.Sprintf("requested host (%s) failed to respond to describe request", clientHost) + return echo.NewHTTPError(http.StatusBadRequest, errMsg) + } + + // Maybe we could do something with this response later + _ = desc + + if len(s.nextCrawlers) != 0 { + blob, err := json.Marshal(body) + if err != nil { + s.log.Warn("could not forward requestCrawl, json err", "err", 
err) + } else { + go func(bodyBlob []byte) { + for _, rpu := range s.nextCrawlers { + pu := rpu.JoinPath("/xrpc/com.atproto.sync.requestCrawl") + response, err := s.httpClient.Post(pu.String(), "application/json", bytes.NewReader(bodyBlob)) + if response != nil && response.Body != nil { + response.Body.Close() + } + if err != nil || response == nil { + s.log.Warn("requestCrawl forward failed", "host", rpu, "err", err) + } else if response.StatusCode != http.StatusOK { + s.log.Warn("requestCrawl forward failed", "host", rpu, "status", response.Status) + } else { + s.log.Info("requestCrawl forward successful", "host", rpu) + } + } + }(blob) + } + } + + return s.slurper.SubscribeToPds(ctx, host, true, false, nil) +} + +func (s *BGS) handleComAtprotoSyncListRepos(ctx context.Context, cursor int64, limit int) (*comatprototypes.SyncListRepos_Output, error) { + // Filter out tombstoned, taken down, and deactivated accounts + q := fmt.Sprintf("id > ? AND NOT tombstoned AND NOT taken_down AND (upstream_status is NULL OR (upstream_status != '%s' AND upstream_status != '%s' AND upstream_status != '%s'))", + events.AccountStatusDeactivated, events.AccountStatusSuspended, events.AccountStatusTakendown) + + // Load the users + users := []*User{} + if err := s.db.Model(&User{}).Where(q, cursor).Order("id").Limit(limit).Find(&users).Error; err != nil { + if err == gorm.ErrRecordNotFound { + return &comatprototypes.SyncListRepos_Output{}, nil + } + s.log.Error("failed to query users", "err", err) + return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to query users") + } + + if len(users) == 0 { + // resp.Repos is an explicit empty array, not just 'nil' + return &comatprototypes.SyncListRepos_Output{ + Repos: []*comatprototypes.SyncListRepos_Repo{}, + }, nil + } + + resp := &comatprototypes.SyncListRepos_Output{ + Repos: make([]*comatprototypes.SyncListRepos_Repo, len(users)), + } + + // Fetch the repo roots for each user + for i := range users { + user := users[i] + + root, err := s.GetRepoRoot(ctx, user.ID) + if err != nil { + s.log.Error("failed to get repo root", "err", err, "did", user.Did) + return nil, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to get repo root for (%s): %v", user.Did, err.Error())) + } + + resp.Repos[i] = &comatprototypes.SyncListRepos_Repo{ + Did: user.Did, + Head: root.String(), + } + } + + // If this is not the last page, set the cursor + if len(users) >= limit && len(users) > 1 { + nextCursor := fmt.Sprintf("%d", users[len(users)-1].ID) + resp.Cursor = &nextCursor + } + + return resp, nil +} + +var ErrUserStatusUnavailable = errors.New("user status unavailable") + +func (s *BGS) handleComAtprotoSyncGetLatestCommit(ctx context.Context, did string) (*comatprototypes.SyncGetLatestCommit_Output, error) { + u, err := s.lookupUserByDid(ctx, did) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, echo.NewHTTPError(http.StatusNotFound, "user not found") + } + return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to lookup user") + } + + if u.GetTombstoned() { + return nil, fmt.Errorf("account was deleted") + } + + if u.GetTakenDown() { + return nil, fmt.Errorf("account was taken down by the Relay") + } + + ustatus := u.GetUpstreamStatus() + if ustatus == events.AccountStatusTakendown { + return nil, fmt.Errorf("account was taken down by its PDS") + } + + if ustatus == events.AccountStatusDeactivated { + return nil, fmt.Errorf("account is temporarily deactivated") + } + + if ustatus == 
events.AccountStatusSuspended { + return nil, fmt.Errorf("account is suspended by its PDS") + } + + var prevState UserPreviousState + err = s.db.First(&prevState, u.ID).Error + if err == nil { + // okay! + } else if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, ErrUserStatusUnavailable + } else { + s.log.Error("user db err", "err", err) + return nil, fmt.Errorf("user prev db err, %w", err) + } + + return &comatprototypes.SyncGetLatestCommit_Output{ + Cid: prevState.Cid.CID.String(), + Rev: prevState.Rev, + }, nil +} diff --git a/cmd/relay/bgs/metrics.go b/cmd/relay/bgs/metrics.go new file mode 100644 index 000000000..da5d6a341 --- /dev/null +++ b/cmd/relay/bgs/metrics.go @@ -0,0 +1,157 @@ +package bgs + +import ( + "errors" + "net/http" + "strconv" + "time" + + "github.com/labstack/echo/v4" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var eventsReceivedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "events_received_counter", + Help: "The total number of events received", +}, []string{"pds"}) + +var eventsHandleDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "events_handle_duration", + Help: "A histogram of handleFedEvent latencies", + Buckets: prometheus.ExponentialBuckets(0.001, 2, 15), +}, []string{"pds"}) + +var repoCommitsReceivedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "repo_commits_received_counter", + Help: "The total number of events received", +}, []string{"pds"}) + +var repoCommitsResultCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "repo_commits_result_counter", + Help: "The results of commit events received", +}, []string{"pds", "status"}) + +var rebasesCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "event_rebases", + Help: "The total number of rebase events received", +}, []string{"pds"}) + +var eventsSentCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "events_sent_counter", + Help: "The total number of events sent to consumers", +}, []string{"remote_addr", "user_agent"}) + +var externalUserCreationAttempts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "bgs_external_user_creation_attempts", + Help: "The total number of external users created", +}) + +var connectedInbound = promauto.NewGauge(prometheus.GaugeOpts{ + Name: "bgs_connected_inbound", + Help: "Number of inbound firehoses we are consuming", +}) + +var newUsersDiscovered = promauto.NewCounter(prometheus.CounterOpts{ + Name: "bgs_new_users_discovered", + Help: "The total number of new users discovered directly from the firehose (not from refs)", +}) + +var reqSz = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "http_request_size_bytes", + Help: "A histogram of request sizes for requests.", + Buckets: prometheus.ExponentialBuckets(100, 10, 8), +}, []string{"code", "method", "path"}) + +var reqDur = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "http_request_duration_seconds", + Help: "A histogram of latencies for requests.", + Buckets: prometheus.ExponentialBuckets(0.001, 2, 15), +}, []string{"code", "method", "path"}) + +var reqCnt = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "http_requests_total", + Help: "A counter for requests to the wrapped handler.", +}, []string{"code", "method", "path"}) + +var resSz = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Name: "http_response_size_bytes", + Help: "A histogram of response sizes for requests.", + Buckets: 
prometheus.ExponentialBuckets(100, 10, 8), +}, []string{"code", "method", "path"}) + +//var userLookupDuration = promauto.NewHistogram(prometheus.HistogramOpts{ +// Name: "relay_user_lookup_duration", +// Help: "A histogram of user lookup latencies", +// Buckets: prometheus.ExponentialBuckets(0.001, 2, 15), +//}) + +var newUserDiscoveryDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Name: "relay_new_user_discovery_duration", + Help: "A histogram of new user discovery latencies", + Buckets: prometheus.ExponentialBuckets(0.001, 2, 15), +}) + +// MetricsMiddleware defines handler function for metrics middleware +func MetricsMiddleware(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + path := c.Path() + if path == "/metrics" || path == "/_health" { + return next(c) + } + + start := time.Now() + requestSize := computeApproximateRequestSize(c.Request()) + + err := next(c) + + status := c.Response().Status + if err != nil { + var httpError *echo.HTTPError + if errors.As(err, &httpError) { + status = httpError.Code + } + if status == 0 || status == http.StatusOK { + status = http.StatusInternalServerError + } + } + + elapsed := float64(time.Since(start)) / float64(time.Second) + + statusStr := strconv.Itoa(status) + method := c.Request().Method + + responseSize := float64(c.Response().Size) + + reqDur.WithLabelValues(statusStr, method, path).Observe(elapsed) + reqCnt.WithLabelValues(statusStr, method, path).Inc() + reqSz.WithLabelValues(statusStr, method, path).Observe(float64(requestSize)) + resSz.WithLabelValues(statusStr, method, path).Observe(responseSize) + + return err + } +} + +func computeApproximateRequestSize(r *http.Request) int { + s := 0 + if r.URL != nil { + s = len(r.URL.Path) + } + + s += len(r.Method) + s += len(r.Proto) + for name, values := range r.Header { + s += len(name) + for _, value := range values { + s += len(value) + } + } + s += len(r.Host) + + // N.B. r.Form and r.MultipartForm are assumed to be included in r.URL. 
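// Illustrative sketch (editor's addition, not part of this patch): one way the middleware
// above could be wired into an echo server, with the Prometheus registry exposed on
// /metrics (a path the middleware deliberately skips, along with /_health). The port and
// overall wiring are assumptions; the relay's real server setup lives elsewhere.
package main

import (
	"net/http"

	"github.com/bluesky-social/indigo/cmd/relay/bgs"
	"github.com/labstack/echo/v4"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

func main() {
	e := echo.New()
	e.Use(bgs.MetricsMiddleware) // records http_request_* series for every handler

	e.GET("/metrics", echo.WrapHandler(promhttp.Handler()))
	e.GET("/_health", func(c echo.Context) error {
		return c.String(http.StatusOK, "ok")
	})

	e.Logger.Fatal(e.Start(":2470"))
}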
+ + if r.ContentLength != -1 { + s += int(r.ContentLength) + } + return s +} diff --git a/cmd/relay/bgs/models.go b/cmd/relay/bgs/models.go new file mode 100644 index 000000000..a3bcaad72 --- /dev/null +++ b/cmd/relay/bgs/models.go @@ -0,0 +1,8 @@ +package bgs + +import "gorm.io/gorm" + +type DomainBan struct { + gorm.Model + Domain string +} diff --git a/cmd/relay/bgs/stubs.go b/cmd/relay/bgs/stubs.go new file mode 100644 index 000000000..1f1a2bbe2 --- /dev/null +++ b/cmd/relay/bgs/stubs.go @@ -0,0 +1,142 @@ +package bgs + +import ( + "errors" + "fmt" + "gorm.io/gorm" + "net/http" + "strconv" + + comatprototypes "github.com/bluesky-social/indigo/api/atproto" + "github.com/bluesky-social/indigo/atproto/syntax" + "github.com/labstack/echo/v4" + "go.opentelemetry.io/otel" +) + +type XRPCError struct { + Message string `json:"message"` +} + +func (s *BGS) RegisterHandlersAppBsky(e *echo.Echo) error { + return nil +} + +func (s *BGS) RegisterHandlersComAtproto(e *echo.Echo) error { + e.GET("/xrpc/com.atproto.sync.getLatestCommit", s.HandleComAtprotoSyncGetLatestCommit) + e.GET("/xrpc/com.atproto.sync.listRepos", s.HandleComAtprotoSyncListRepos) + e.POST("/xrpc/com.atproto.sync.requestCrawl", s.HandleComAtprotoSyncRequestCrawl) + return nil +} + +func (s *BGS) HandleComAtprotoSyncGetLatestCommit(c echo.Context) error { + ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleComAtprotoSyncGetLatestCommit") + defer span.End() + did := c.QueryParam("did") + + _, err := syntax.ParseDID(did) + if err != nil { + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid did: %s", did)}) + } + + var out *comatprototypes.SyncGetLatestCommit_Output + var handleErr error + // func (s *BGS) handleComAtprotoSyncGetLatestCommit(ctx context.Context,did string) (*comatprototypes.SyncGetLatestCommit_Output, error) + out, handleErr = s.handleComAtprotoSyncGetLatestCommit(ctx, did) + if handleErr != nil { + return handleErr + } + return c.JSON(200, out) +} + +func (s *BGS) HandleComAtprotoSyncListRepos(c echo.Context) error { + ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleComAtprotoSyncListRepos") + defer span.End() + + cursorQuery := c.QueryParam("cursor") + limitQuery := c.QueryParam("limit") + + var err error + + limit := 500 + if limitQuery != "" { + limit, err = strconv.Atoi(limitQuery) + if err != nil || limit < 1 || limit > 1000 { + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid limit: %s", limitQuery)}) + } + } + + cursor := int64(0) + if cursorQuery != "" { + cursor, err = strconv.ParseInt(cursorQuery, 10, 64) + if err != nil || cursor < 0 { + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid cursor: %s", cursorQuery)}) + } + } + + out, handleErr := s.handleComAtprotoSyncListRepos(ctx, cursor, limit) + if handleErr != nil { + return handleErr + } + return c.JSON(200, out) +} + +// HandleComAtprotoSyncGetRepo handles /xrpc/com.atproto.sync.getRepo +// returns 3xx to same URL at source PDS +func (s *BGS) HandleComAtprotoSyncGetRepo(c echo.Context) error { + // no request object, only params + params := c.QueryParams() + var did string + hasDid := false + for paramName, pvl := range params { + switch paramName { + case "did": + if len(pvl) == 1 { + did = pvl[0] + hasDid = true + } else if len(pvl) > 1 { + return c.JSON(http.StatusBadRequest, XRPCError{Message: "only allow one did param"}) + } + case "since": + // ok + default: + return c.JSON(http.StatusBadRequest, 
XRPCError{Message: fmt.Sprintf("invalid param: %s", paramName)}) + } + } + if !hasDid { + return c.JSON(http.StatusBadRequest, XRPCError{Message: "need did param"}) + } + + var pdsHostname string + err := s.db.Raw("SELECT pds.host FROM users JOIN pds ON users.pds = pds.id WHERE users.did = ?", did).Scan(&pdsHostname).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return c.JSON(http.StatusNotFound, XRPCError{Message: "NULL"}) + } + s.log.Error("user.pds.host lookup", "err", err) + return c.JSON(http.StatusInternalServerError, XRPCError{Message: "sorry"}) + } + + nextUrl := *(c.Request().URL) + nextUrl.Host = pdsHostname + if nextUrl.Scheme == "" { + nextUrl.Scheme = "https" + } + return c.Redirect(http.StatusFound, nextUrl.String()) +} + +func (s *BGS) HandleComAtprotoSyncRequestCrawl(c echo.Context) error { + ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleComAtprotoSyncRequestCrawl") + defer span.End() + + var body comatprototypes.SyncRequestCrawl_Input + if err := c.Bind(&body); err != nil { + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid body: %s", err)}) + } + var handleErr error + // func (s *BGS) handleComAtprotoSyncRequestCrawl(ctx context.Context,body *comatprototypes.SyncRequestCrawl_Input) error + handleErr = s.handleComAtprotoSyncRequestCrawl(ctx, &body) + if handleErr != nil { + return handleErr + } + return nil +} diff --git a/cmd/relay/events/cbor_gen.go b/cmd/relay/events/cbor_gen.go new file mode 100644 index 000000000..8e13f8339 --- /dev/null +++ b/cmd/relay/events/cbor_gen.go @@ -0,0 +1,303 @@ +// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT. + +package events + +import ( + "fmt" + "io" + "math" + "sort" + + cid "github.com/ipfs/go-cid" + cbg "github.com/whyrusleeping/cbor-gen" + xerrors "golang.org/x/xerrors" +) + +var _ = xerrors.Errorf +var _ = cid.Undef +var _ = math.E +var _ = sort.Sort + +func (t *EventHeader) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + + if _, err := cw.Write([]byte{162}); err != nil { + return err + } + + // t.MsgType (string) (string) + if len("t") > 1000000 { + return xerrors.Errorf("Value in field \"t\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("t"))); err != nil { + return err + } + if _, err := cw.WriteString(string("t")); err != nil { + return err + } + + if len(t.MsgType) > 1000000 { + return xerrors.Errorf("Value in field t.MsgType was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.MsgType))); err != nil { + return err + } + if _, err := cw.WriteString(string(t.MsgType)); err != nil { + return err + } + + // t.Op (int64) (int64) + if len("op") > 1000000 { + return xerrors.Errorf("Value in field \"op\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("op"))); err != nil { + return err + } + if _, err := cw.WriteString(string("op")); err != nil { + return err + } + + if t.Op >= 0 { + if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(t.Op)); err != nil { + return err + } + } else { + if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-t.Op-1)); err != nil { + return err + } + } + + return nil +} + +func (t *EventHeader) UnmarshalCBOR(r io.Reader) (err error) { + *t = EventHeader{} + + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + defer func() { + if err 
== io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("EventHeader: map struct too large (%d)", extra) + } + + n := extra + + nameBuf := make([]byte, 2) + for i := uint64(0); i < n; i++ { + nameLen, ok, err := cbg.ReadFullStringIntoBuf(cr, nameBuf, 1000000) + if err != nil { + return err + } + + if !ok { + // Field doesn't exist on this type, so ignore it + if err := cbg.ScanForLinks(cr, func(cid.Cid) {}); err != nil { + return err + } + continue + } + + switch string(nameBuf[:nameLen]) { + // t.MsgType (string) (string) + case "t": + + { + sval, err := cbg.ReadStringWithMax(cr, 1000000) + if err != nil { + return err + } + + t.MsgType = string(sval) + } + // t.Op (int64) (int64) + case "op": + { + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + var extraI int64 + switch maj { + case cbg.MajUnsignedInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 positive overflow") + } + case cbg.MajNegativeInt: + extraI = int64(extra) + if extraI < 0 { + return fmt.Errorf("int64 negative overflow") + } + extraI = -1 - extraI + default: + return fmt.Errorf("wrong type for int64 field: %d", maj) + } + + t.Op = int64(extraI) + } + + default: + // Field doesn't exist on this type, so ignore it + if err := cbg.ScanForLinks(r, func(cid.Cid) {}); err != nil { + return err + } + } + } + + return nil +} +func (t *ErrorFrame) MarshalCBOR(w io.Writer) error { + if t == nil { + _, err := w.Write(cbg.CborNull) + return err + } + + cw := cbg.NewCborWriter(w) + + if _, err := cw.Write([]byte{162}); err != nil { + return err + } + + // t.Error (string) (string) + if len("error") > 1000000 { + return xerrors.Errorf("Value in field \"error\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("error"))); err != nil { + return err + } + if _, err := cw.WriteString(string("error")); err != nil { + return err + } + + if len(t.Error) > 1000000 { + return xerrors.Errorf("Value in field t.Error was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Error))); err != nil { + return err + } + if _, err := cw.WriteString(string(t.Error)); err != nil { + return err + } + + // t.Message (string) (string) + if len("message") > 1000000 { + return xerrors.Errorf("Value in field \"message\" was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("message"))); err != nil { + return err + } + if _, err := cw.WriteString(string("message")); err != nil { + return err + } + + if len(t.Message) > 1000000 { + return xerrors.Errorf("Value in field t.Message was too long") + } + + if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Message))); err != nil { + return err + } + if _, err := cw.WriteString(string(t.Message)); err != nil { + return err + } + return nil +} + +func (t *ErrorFrame) UnmarshalCBOR(r io.Reader) (err error) { + *t = ErrorFrame{} + + cr := cbg.NewCborReader(r) + + maj, extra, err := cr.ReadHeader() + if err != nil { + return err + } + defer func() { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + }() + + if maj != cbg.MajMap { + return fmt.Errorf("cbor input should be of type map") + } + + if extra > cbg.MaxLength { + return fmt.Errorf("ErrorFrame: map struct too large (%d)", extra) + } + + n := extra + + nameBuf := make([]byte, 7) + for i := uint64(0); i < n; i++ { + nameLen, ok, err := cbg.ReadFullStringIntoBuf(cr, nameBuf, 
1000000) + if err != nil { + return err + } + + if !ok { + // Field doesn't exist on this type, so ignore it + if err := cbg.ScanForLinks(cr, func(cid.Cid) {}); err != nil { + return err + } + continue + } + + switch string(nameBuf[:nameLen]) { + // t.Error (string) (string) + case "error": + + { + sval, err := cbg.ReadStringWithMax(cr, 1000000) + if err != nil { + return err + } + + t.Error = string(sval) + } + // t.Message (string) (string) + case "message": + + { + sval, err := cbg.ReadStringWithMax(cr, 1000000) + if err != nil { + return err + } + + t.Message = string(sval) + } + + default: + // Field doesn't exist on this type, so ignore it + if err := cbg.ScanForLinks(r, func(cid.Cid) {}); err != nil { + return err + } + } + } + + return nil +} diff --git a/cmd/relay/events/consumer.go b/cmd/relay/events/consumer.go new file mode 100644 index 000000000..e7090e971 --- /dev/null +++ b/cmd/relay/events/consumer.go @@ -0,0 +1,372 @@ +package events + +import ( + "context" + "fmt" + "io" + "log/slog" + "net" + "time" + + "github.com/RussellLuo/slidingwindow" + comatproto "github.com/bluesky-social/indigo/api/atproto" + "github.com/prometheus/client_golang/prometheus" + + "github.com/gorilla/websocket" +) + +type RepoStreamCallbacks struct { + RepoCommit func(evt *comatproto.SyncSubscribeRepos_Commit) error + RepoSync func(evt *comatproto.SyncSubscribeRepos_Sync) error + RepoHandle func(evt *comatproto.SyncSubscribeRepos_Handle) error + RepoIdentity func(evt *comatproto.SyncSubscribeRepos_Identity) error + RepoAccount func(evt *comatproto.SyncSubscribeRepos_Account) error + RepoInfo func(evt *comatproto.SyncSubscribeRepos_Info) error + RepoMigrate func(evt *comatproto.SyncSubscribeRepos_Migrate) error + RepoTombstone func(evt *comatproto.SyncSubscribeRepos_Tombstone) error + LabelLabels func(evt *comatproto.LabelSubscribeLabels_Labels) error + LabelInfo func(evt *comatproto.LabelSubscribeLabels_Info) error + Error func(evt *ErrorFrame) error +} + +func (rsc *RepoStreamCallbacks) EventHandler(ctx context.Context, xev *XRPCStreamEvent) error { + switch { + case xev.RepoCommit != nil && rsc.RepoCommit != nil: + return rsc.RepoCommit(xev.RepoCommit) + case xev.RepoSync != nil && rsc.RepoSync != nil: + return rsc.RepoSync(xev.RepoSync) + case xev.RepoHandle != nil && rsc.RepoHandle != nil: + return rsc.RepoHandle(xev.RepoHandle) + case xev.RepoInfo != nil && rsc.RepoInfo != nil: + return rsc.RepoInfo(xev.RepoInfo) + case xev.RepoMigrate != nil && rsc.RepoMigrate != nil: + return rsc.RepoMigrate(xev.RepoMigrate) + case xev.RepoIdentity != nil && rsc.RepoIdentity != nil: + return rsc.RepoIdentity(xev.RepoIdentity) + case xev.RepoAccount != nil && rsc.RepoAccount != nil: + return rsc.RepoAccount(xev.RepoAccount) + case xev.RepoTombstone != nil && rsc.RepoTombstone != nil: + return rsc.RepoTombstone(xev.RepoTombstone) + case xev.LabelLabels != nil && rsc.LabelLabels != nil: + return rsc.LabelLabels(xev.LabelLabels) + case xev.LabelInfo != nil && rsc.LabelInfo != nil: + return rsc.LabelInfo(xev.LabelInfo) + case xev.Error != nil && rsc.Error != nil: + return rsc.Error(xev.Error) + default: + return nil + } +} + +type InstrumentedRepoStreamCallbacks struct { + limiters []*slidingwindow.Limiter + Next func(ctx context.Context, xev *XRPCStreamEvent) error +} + +func NewInstrumentedRepoStreamCallbacks(limiters []*slidingwindow.Limiter, next func(ctx context.Context, xev *XRPCStreamEvent) error) *InstrumentedRepoStreamCallbacks { + return &InstrumentedRepoStreamCallbacks{ + limiters: 
limiters, + Next: next, + } +} + +func waitForLimiter(ctx context.Context, lim *slidingwindow.Limiter) error { + if lim.Allow() { + return nil + } + + // wait until the limiter is ready (check every 100ms) + t := time.NewTicker(100 * time.Millisecond) + defer t.Stop() + + for !lim.Allow() { + select { + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + } + } + + return nil +} + +func (rsc *InstrumentedRepoStreamCallbacks) EventHandler(ctx context.Context, xev *XRPCStreamEvent) error { + // Wait on all limiters before calling the next handler + for _, lim := range rsc.limiters { + if err := waitForLimiter(ctx, lim); err != nil { + return err + } + } + return rsc.Next(ctx, xev) +} + +type instrumentedReader struct { + r io.Reader + addr string + bytesCounter prometheus.Counter +} + +func (sr *instrumentedReader) Read(p []byte) (int, error) { + n, err := sr.r.Read(p) + sr.bytesCounter.Add(float64(n)) + return n, err +} + +// HandleRepoStream +// con is source of events +// sched gets AddWork for each event +// log may be nil for default logger +func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, log *slog.Logger) error { + if log == nil { + log = slog.Default().With("system", "events") + } + ctx, cancel := context.WithCancel(ctx) + defer cancel() + defer sched.Shutdown() + + remoteAddr := con.RemoteAddr().String() + + go func() { + t := time.NewTicker(time.Second * 30) + defer t.Stop() + failcount := 0 + + for { + + select { + case <-t.C: + if err := con.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil { + log.Warn("failed to ping", "err", err) + failcount++ + if failcount >= 4 { + log.Error("too many ping fails", "count", failcount) + con.Close() + return + } + } else { + failcount = 0 // ok ping + } + case <-ctx.Done(): + con.Close() + return + } + } + }() + + con.SetPingHandler(func(message string) error { + err := con.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60)) + if err == websocket.ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + }) + + con.SetPongHandler(func(_ string) error { + if err := con.SetReadDeadline(time.Now().Add(time.Minute)); err != nil { + log.Error("failed to set read deadline", "err", err) + } + + return nil + }) + + lastSeq := int64(-1) + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + mt, rawReader, err := con.NextReader() + if err != nil { + return fmt.Errorf("con err at read: %w", err) + } + + switch mt { + default: + return fmt.Errorf("expected binary message from subscription endpoint") + case websocket.BinaryMessage: + // ok + } + + r := &instrumentedReader{ + r: rawReader, + addr: remoteAddr, + bytesCounter: bytesFromStreamCounter.WithLabelValues(remoteAddr), + } + + var header EventHeader + if err := header.UnmarshalCBOR(r); err != nil { + return fmt.Errorf("reading header: %w", err) + } + + eventsFromStreamCounter.WithLabelValues(remoteAddr).Inc() + + switch header.Op { + case EvtKindMessage: + switch header.MsgType { + case "#commit": + var evt comatproto.SyncSubscribeRepos_Commit + if err := evt.UnmarshalCBOR(r); err != nil { + return fmt.Errorf("reading repoCommit event: %w", err) + } + + if evt.Seq < lastSeq { + log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) + } + + lastSeq = evt.Seq + + if err := sched.AddWork(ctx, evt.Repo, &XRPCStreamEvent{ + RepoCommit: &evt, + }); err != nil { + return err + } + case "#sync": 
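+ // #sync carries a snapshot declaration for a single account (did, rev, and a CAR slice in blocks) and is dispatched per-DID, same as #commit below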
+ var evt comatproto.SyncSubscribeRepos_Sync + if err := evt.UnmarshalCBOR(r); err != nil { + return fmt.Errorf("reading repoCommit event: %w", err) + } + + if evt.Seq < lastSeq { + log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) + } + + lastSeq = evt.Seq + + if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ + RepoSync: &evt, + }); err != nil { + return err + } + case "#handle": + var evt comatproto.SyncSubscribeRepos_Handle + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + + if evt.Seq < lastSeq { + log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) + } + lastSeq = evt.Seq + + if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ + RepoHandle: &evt, + }); err != nil { + return err + } + case "#identity": + var evt comatproto.SyncSubscribeRepos_Identity + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + + if evt.Seq < lastSeq { + log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) + } + lastSeq = evt.Seq + + if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ + RepoIdentity: &evt, + }); err != nil { + return err + } + case "#account": + var evt comatproto.SyncSubscribeRepos_Account + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + + if evt.Seq < lastSeq { + log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) + } + lastSeq = evt.Seq + + if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ + RepoAccount: &evt, + }); err != nil { + return err + } + case "#info": + // TODO: this might also be a LabelInfo (as opposed to RepoInfo) + var evt comatproto.SyncSubscribeRepos_Info + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + + if err := sched.AddWork(ctx, "", &XRPCStreamEvent{ + RepoInfo: &evt, + }); err != nil { + return err + } + case "#migrate": + var evt comatproto.SyncSubscribeRepos_Migrate + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + + if evt.Seq < lastSeq { + log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) + } + lastSeq = evt.Seq + + if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ + RepoMigrate: &evt, + }); err != nil { + return err + } + case "#tombstone": + var evt comatproto.SyncSubscribeRepos_Tombstone + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + + if evt.Seq < lastSeq { + log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) + } + lastSeq = evt.Seq + + if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ + RepoTombstone: &evt, + }); err != nil { + return err + } + case "#labels": + var evt comatproto.LabelSubscribeLabels_Labels + if err := evt.UnmarshalCBOR(r); err != nil { + return fmt.Errorf("reading Labels event: %w", err) + } + + if evt.Seq < lastSeq { + log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) + } + + lastSeq = evt.Seq + + if err := sched.AddWork(ctx, "", &XRPCStreamEvent{ + LabelLabels: &evt, + }); err != nil { + return err + } + } + + case EvtKindErrorFrame: + var errframe ErrorFrame + if err := errframe.UnmarshalCBOR(r); err != nil { + return err + } + + if err := sched.AddWork(ctx, "", &XRPCStreamEvent{ + Error: &errframe, + }); err != nil { + return err + } + + default: + return fmt.Errorf("unrecognized event stream type: %d", header.Op) + } + + } +} diff --git a/cmd/relay/events/diskpersist.go b/cmd/relay/events/diskpersist.go new file mode 100644 index 000000000..4b4b367e3 --- /dev/null +++ 
b/cmd/relay/events/diskpersist.go @@ -0,0 +1,983 @@ +package events + +import ( + "bufio" + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "time" + + "github.com/bluesky-social/indigo/api/atproto" + "github.com/bluesky-social/indigo/cmd/relay/models" + arc "github.com/hashicorp/golang-lru/arc/v2" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + cbg "github.com/whyrusleeping/cbor-gen" + "gorm.io/gorm" +) + +type DiskPersistence struct { + primaryDir string + archiveDir string + eventsPerFile int64 + writeBufferSize int + retention time.Duration + + meta *gorm.DB + + broadcast func(*XRPCStreamEvent) + + logfi *os.File + + eventCounter int64 + curSeq int64 + timeSequence bool + + uids UidSource + uidCache *arc.ARCCache[models.Uid, string] // TODO: unused + didCache *arc.ARCCache[string, models.Uid] + + writers *sync.Pool + buffers *sync.Pool + scratch []byte + + outbuf *bytes.Buffer + evtbuf []persistJob + + shutdown chan struct{} + + lk sync.Mutex +} + +type persistJob struct { + Bytes []byte + Evt *XRPCStreamEvent + Buffer *bytes.Buffer // so we can put it back in the pool when we're done +} + +type jobResult struct { + Err error + Seq int64 +} + +const ( + EvtFlagTakedown = 1 << iota + EvtFlagRebased +) + +var _ (EventPersistence) = (*DiskPersistence)(nil) + +type DiskPersistOptions struct { + UIDCacheSize int + DIDCacheSize int + EventsPerFile int64 + WriteBufferSize int + Retention time.Duration + + TimeSequence bool +} + +func DefaultDiskPersistOptions() *DiskPersistOptions { + return &DiskPersistOptions{ + EventsPerFile: 10_000, + UIDCacheSize: 1_000_000, + DIDCacheSize: 1_000_000, + WriteBufferSize: 50, + Retention: time.Hour * 24 * 3, // 3 days + } +} + +type UidSource interface { + DidToUid(ctx context.Context, did string) (models.Uid, error) +} + +func NewDiskPersistence(primaryDir, archiveDir string, db *gorm.DB, opts *DiskPersistOptions) (*DiskPersistence, error) { + if opts == nil { + opts = DefaultDiskPersistOptions() + } + + uidCache, err := arc.NewARC[models.Uid, string](opts.UIDCacheSize) + if err != nil { + return nil, fmt.Errorf("failed to create uid cache: %w", err) + } + + didCache, err := arc.NewARC[string, models.Uid](opts.DIDCacheSize) + if err != nil { + return nil, fmt.Errorf("failed to create did cache: %w", err) + } + + db.AutoMigrate(&LogFileRef{}) + + bufpool := &sync.Pool{ + New: func() any { + return new(bytes.Buffer) + }, + } + + wrpool := &sync.Pool{ + New: func() any { + return cbg.NewCborWriter(nil) + }, + } + + dp := &DiskPersistence{ + meta: db, + primaryDir: primaryDir, + archiveDir: archiveDir, + buffers: bufpool, + retention: opts.Retention, + writers: wrpool, + uidCache: uidCache, + didCache: didCache, + eventsPerFile: opts.EventsPerFile, + scratch: make([]byte, headerSize), + outbuf: new(bytes.Buffer), + writeBufferSize: opts.WriteBufferSize, + shutdown: make(chan struct{}), + timeSequence: opts.TimeSequence, + } + + if err := dp.resumeLog(); err != nil { + return nil, err + } + + go dp.flushRoutine() + + go dp.garbageCollectRoutine() + + return dp, nil +} + +type LogFileRef struct { + gorm.Model + Path string + Archived bool + SeqStart int64 +} + +func (dp *DiskPersistence) SetUidSource(uids UidSource) { + dp.uids = uids +} + +func (dp *DiskPersistence) resumeLog() error { + var lfr LogFileRef + if err := dp.meta.Order("seq_start desc").Limit(1).Find(&lfr).Error; err != nil { + return err + } + + if lfr.ID == 0 { 
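+ // gorm's Find does not return ErrRecordNotFound; it leaves the model zero-valued, so ID == 0 means no log file rows exist yet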
+ // no files, start anew! + return dp.initLogFile() + } + + // 0 for the mode is fine since that is only used if O_CREAT is passed + fi, err := os.OpenFile(filepath.Join(dp.primaryDir, lfr.Path), os.O_RDWR, 0) + if err != nil { + return err + } + + seq, err := scanForLastSeq(fi, -1) + if err != nil { + return fmt.Errorf("failed to scan log file for last seqno: %w", err) + } + + dp.curSeq = seq + 1 + dp.logfi = fi + + return nil +} + +func (dp *DiskPersistence) initLogFile() error { + if err := os.MkdirAll(dp.primaryDir, 0775); err != nil { + return err + } + + p := filepath.Join(dp.primaryDir, "evts-0") + fi, err := os.Create(p) + if err != nil { + return err + } + + if err := dp.meta.Create(&LogFileRef{ + Path: "evts-0", + SeqStart: 0, + }).Error; err != nil { + return err + } + + dp.logfi = fi + dp.curSeq = 1 + return nil +} + +// swapLog swaps the current log file out for a new empty one +// must only be called while holding dp.lk +func (dp *DiskPersistence) swapLog(ctx context.Context) error { + if err := dp.logfi.Close(); err != nil { + return fmt.Errorf("failed to close current log file: %w", err) + } + + fname := fmt.Sprintf("evts-%d", dp.curSeq) + nextp := filepath.Join(dp.primaryDir, fname) + + fi, err := os.Create(nextp) + if err != nil { + return err + } + + if err := dp.meta.Create(&LogFileRef{ + Path: fname, + SeqStart: dp.curSeq, + }).Error; err != nil { + return err + } + + dp.logfi = fi + return nil +} + +func scanForLastSeq(fi *os.File, end int64) (int64, error) { + scratch := make([]byte, headerSize) + + var lastSeq int64 = -1 + var offset int64 + for { + eh, err := readHeader(fi, scratch) + if err != nil { + if errors.Is(err, io.EOF) { + return lastSeq, nil + } + return 0, err + } + + if end > 0 && eh.Seq > end { + // return to beginning of offset + n, err := fi.Seek(offset, io.SeekStart) + if err != nil { + return 0, err + } + + if n != offset { + return 0, fmt.Errorf("rewind seek failed") + } + + return eh.Seq, nil + } + + lastSeq = eh.Seq + + noff, err := fi.Seek(int64(eh.Len), io.SeekCurrent) + if err != nil { + return 0, err + } + + if noff != offset+headerSize+int64(eh.Len) { + // TODO: must recover from this + return 0, fmt.Errorf("did not seek to next event properly") + } + + offset = noff + } +} + +const ( + evtKindCommit = 1 + evtKindHandle = 2 + evtKindTombstone = 3 + evtKindIdentity = 4 + evtKindAccount = 5 +) + +var emptyHeader = make([]byte, headerSize) + +func (dp *DiskPersistence) addJobToQueue(ctx context.Context, job persistJob) error { + dp.lk.Lock() + defer dp.lk.Unlock() + + if err := dp.doPersist(ctx, job); err != nil { + return err + } + + // TODO: for some reason replacing this constant with p.writeBufferSize dramatically reduces perf... + if len(dp.evtbuf) > 400 { + if err := dp.flushLog(ctx); err != nil { + return fmt.Errorf("failed to flush disk log: %w", err) + } + } + + return nil +} + +func (dp *DiskPersistence) flushRoutine() { + t := time.NewTicker(time.Millisecond * 100) + + for { + ctx := context.Background() + select { + case <-dp.shutdown: + return + case <-t.C: + dp.lk.Lock() + if err := dp.flushLog(ctx); err != nil { + // TODO: this happening is quite bad. 
Need a recovery strategy + log.Error("failed to flush disk log", "err", err) + } + dp.lk.Unlock() + } + } +} + +func (dp *DiskPersistence) flushLog(ctx context.Context) error { + if len(dp.evtbuf) == 0 { + return nil + } + + _, err := io.Copy(dp.logfi, dp.outbuf) + if err != nil { + return err + } + + dp.outbuf.Truncate(0) + + for _, ej := range dp.evtbuf { + dp.broadcast(ej.Evt) + ej.Buffer.Truncate(0) + dp.buffers.Put(ej.Buffer) + } + + dp.evtbuf = dp.evtbuf[:0] + + return nil +} + +func (dp *DiskPersistence) garbageCollectRoutine() { + t := time.NewTicker(time.Hour) + + for { + ctx := context.Background() + select { + // Closing a channel can be listened to with multiple routines: https://goplay.tools/snippet/UcwbC0CeJAL + case <-dp.shutdown: + return + case <-t.C: + if errs := dp.garbageCollect(ctx); len(errs) > 0 { + for _, err := range errs { + log.Error("garbage collection error", "err", err) + } + } + } + } +} + +var garbageCollectionsExecuted = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "disk_persister_garbage_collections_executed", + Help: "Number of garbage collections executed", +}, []string{}) + +var garbageCollectionErrors = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "disk_persister_garbage_collections_errors", + Help: "Number of errors encountered during garbage collection", +}, []string{}) + +var refsGarbageCollected = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "disk_persister_garbage_collections_refs_collected", + Help: "Number of refs collected during garbage collection", +}, []string{}) + +var filesGarbageCollected = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "disk_persister_garbage_collections_files_collected", + Help: "Number of files collected during garbage collection", +}, []string{}) + +func (dp *DiskPersistence) garbageCollect(ctx context.Context) []error { + garbageCollectionsExecuted.WithLabelValues().Inc() + + // Grab refs created before the retention period + var refs []LogFileRef + var errs []error + + defer func() { + garbageCollectionErrors.WithLabelValues().Add(float64(len(errs))) + }() + + if err := dp.meta.WithContext(ctx).Find(&refs, "created_at < ?", time.Now().Add(-dp.retention)).Error; err != nil { + return []error{err} + } + + oldRefsFound := len(refs) + refsDeleted := 0 + filesDeleted := 0 + + // In the future if we want to support Archiving, we could do that here instead of deleting + for _, r := range refs { + dp.lk.Lock() + currentLogfile := dp.logfi.Name() + dp.lk.Unlock() + + if filepath.Join(dp.primaryDir, r.Path) == currentLogfile { + // Don't delete the current log file + log.Info("skipping deletion of current log file") + continue + } + + // Delete the ref in the database to prevent playback from finding it + if err := dp.meta.WithContext(ctx).Delete(&r).Error; err != nil { + errs = append(errs, err) + continue + } + refsDeleted++ + + // Delete the file from disk + if err := os.Remove(filepath.Join(dp.primaryDir, r.Path)); err != nil { + errs = append(errs, err) + continue + } + filesDeleted++ + } + + refsGarbageCollected.WithLabelValues().Add(float64(refsDeleted)) + filesGarbageCollected.WithLabelValues().Add(float64(filesDeleted)) + + log.Info("garbage collection complete", + "filesDeleted", filesDeleted, + "refsDeleted", refsDeleted, + "oldRefsFound", oldRefsFound, + ) + + return errs +} + +func (dp *DiskPersistence) doPersist(ctx context.Context, j persistJob) error { + b := j.Bytes + e := j.Evt + seq := dp.curSeq + if dp.timeSequence { + seq = time.Now().UnixMicro() + if seq < dp.curSeq 
{ + seq = dp.curSeq + } + dp.curSeq = seq + 1 + } else { + dp.curSeq++ + } + + // Set sequence number in event header + binary.LittleEndian.PutUint64(b[20:], uint64(seq)) + + switch { + case e.RepoCommit != nil: + cc := *e.RepoCommit + cc.Seq = seq + j.Evt.RepoCommit = &cc + case e.RepoHandle != nil: + hc := *e.RepoHandle + hc.Seq = seq + j.Evt.RepoHandle = &hc + case e.RepoIdentity != nil: + ic := *e.RepoIdentity + ic.Seq = seq + j.Evt.RepoIdentity = &ic + case e.RepoAccount != nil: + ac := *e.RepoAccount + ac.Seq = seq + j.Evt.RepoAccount = &ac + case e.RepoTombstone != nil: + tc := *e.RepoTombstone + tc.Seq = seq + j.Evt.RepoTombstone = &tc + default: + // only the event kinds above get persisted right now + // we should not actually ever get here... + return nil + } + + // TODO: does this guarantee a full write? + _, err := dp.outbuf.Write(b) + if err != nil { + return err + } + + dp.evtbuf = append(dp.evtbuf, j) + + dp.eventCounter++ + if dp.eventCounter%dp.eventsPerFile == 0 { + if err := dp.flushLog(ctx); err != nil { + return err + } + + // time to roll the log file + if err := dp.swapLog(ctx); err != nil { + return err + } + } + + return nil +} + +func (dp *DiskPersistence) Persist(ctx context.Context, e *XRPCStreamEvent) error { + buffer := dp.buffers.Get().(*bytes.Buffer) + cw := dp.writers.Get().(*cbg.CborWriter) + cw.SetWriter(buffer) + + buffer.Truncate(0) + + buffer.Write(emptyHeader) + + var did string + var evtKind uint32 + switch { + case e.RepoCommit != nil: + evtKind = evtKindCommit + did = e.RepoCommit.Repo + if err := e.RepoCommit.MarshalCBOR(cw); err != nil { + return fmt.Errorf("failed to marshal: %w", err) + } + case e.RepoHandle != nil: + evtKind = evtKindHandle + did = e.RepoHandle.Did + if err := e.RepoHandle.MarshalCBOR(cw); err != nil { + return fmt.Errorf("failed to marshal: %w", err) + } + case e.RepoIdentity != nil: + evtKind = evtKindIdentity + did = e.RepoIdentity.Did + if err := e.RepoIdentity.MarshalCBOR(cw); err != nil { + return fmt.Errorf("failed to marshal: %w", err) + } + case e.RepoAccount != nil: + evtKind = evtKindAccount + did = e.RepoAccount.Did + if err := e.RepoAccount.MarshalCBOR(cw); err != nil { + return fmt.Errorf("failed to marshal: %w", err) + } + case e.RepoTombstone != nil: + evtKind = evtKindTombstone + did = e.RepoTombstone.Did + if err := e.RepoTombstone.MarshalCBOR(cw); err != nil { + return fmt.Errorf("failed to marshal: %w", err) + } + default: + return nil + // only the event kinds above get persisted right now + } + + usr, err := dp.uidForDid(ctx, did) + if err != nil { + return err + } + + b := buffer.Bytes() + + // Set flags in header (no flags for now) + binary.LittleEndian.PutUint32(b, 0) + // Set event kind in header + binary.LittleEndian.PutUint32(b[4:], evtKind) + // Set event length in header + binary.LittleEndian.PutUint32(b[8:], uint32(len(b)-headerSize)) + // Set user UID in header + binary.LittleEndian.PutUint64(b[12:], uint64(usr)) + + return dp.addJobToQueue(ctx, persistJob{ + Bytes: b, + Evt: e, + Buffer: buffer, + }) +} + +type evtHeader struct { + Flags uint32 + Kind uint32 + Seq int64 + Usr models.Uid + Len uint32 +} + +func (eh *evtHeader) Len64() int64 { + return int64(eh.Len) +} + +const headerSize = 4 + 4 + 4 + 8 + 8 + +func readHeader(r io.Reader, scratch []byte) (*evtHeader, error) { + if len(scratch) < headerSize { + return nil, fmt.Errorf("must pass scratch buffer of at least %d bytes", headerSize) + } + + scratch = scratch[:headerSize] + _, err := io.ReadFull(r, scratch) + if err != nil { + return nil, fmt.Errorf("reading 
header: %w", err) + } + + flags := binary.LittleEndian.Uint32(scratch[:4]) + kind := binary.LittleEndian.Uint32(scratch[4:8]) + l := binary.LittleEndian.Uint32(scratch[8:12]) + usr := binary.LittleEndian.Uint64(scratch[12:20]) + seq := binary.LittleEndian.Uint64(scratch[20:28]) + + return &evtHeader{ + Flags: flags, + Kind: kind, + Len: l, + Usr: models.Uid(usr), + Seq: int64(seq), + }, nil +} + +func (dp *DiskPersistence) writeHeader(ctx context.Context, flags uint32, kind uint32, l uint32, usr uint64, seq int64) error { + binary.LittleEndian.PutUint32(dp.scratch, flags) + binary.LittleEndian.PutUint32(dp.scratch[4:], kind) + binary.LittleEndian.PutUint32(dp.scratch[8:], l) + binary.LittleEndian.PutUint64(dp.scratch[12:], usr) + binary.LittleEndian.PutUint64(dp.scratch[20:], uint64(seq)) + + nw, err := dp.logfi.Write(dp.scratch) + if err != nil { + return err + } + + if nw != headerSize { + return fmt.Errorf("only wrote %d bytes for header", nw) + } + + return nil +} + +func (dp *DiskPersistence) uidForDid(ctx context.Context, did string) (models.Uid, error) { + if uid, ok := dp.didCache.Get(did); ok { + return uid, nil + } + + uid, err := dp.uids.DidToUid(ctx, did) + if err != nil { + return 0, err + } + + dp.didCache.Add(did, uid) + + return uid, nil +} + +func (dp *DiskPersistence) Playback(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error) error { + var logs []LogFileRef + needslogs := true + if since != 0 { + // find the log file that starts before our since + result := dp.meta.Debug().Order("seq_start desc").Where("seq_start < ?", since).Limit(1).Find(&logs) + if result.Error != nil { + return result.Error + } + if result.RowsAffected != 0 { + needslogs = false + } + } + + // playback data from all the log files we found, then check the db to see if more were written during playback. + // repeat a few times but not unboundedly. + // don't decrease '10' below 2 because we should always do two passes through this if the above before-chunk query was used. 
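+ // PlaybackLogfiles returns a nil lastSeq once it has read to the end of the newest known file, which ends this loop early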
+ for i := 0; i < 10; i++ { + if needslogs { + if err := dp.meta.Debug().Order("seq_start asc").Find(&logs, "seq_start >= ?", since).Error; err != nil { + return err + } + } + + lastSeq, err := dp.PlaybackLogfiles(ctx, since, cb, logs) + if err != nil { + return err + } + + // No lastSeq implies that we read until the end of known events + if lastSeq == nil { + break + } + + since = *lastSeq + needslogs = true + } + + return nil +} + +func (dp *DiskPersistence) PlaybackLogfiles(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error, logFiles []LogFileRef) (*int64, error) { + for i, lf := range logFiles { + lastSeq, err := dp.readEventsFrom(ctx, since, filepath.Join(dp.primaryDir, lf.Path), cb) + if err != nil { + return nil, err + } + since = 0 + if i == len(logFiles)-1 && + lastSeq != nil && + (*lastSeq-lf.SeqStart) == dp.eventsPerFile-1 { + // There may be more log files to read since the last one was full + return lastSeq, nil + } + } + + return nil, nil +} + +func postDoNotEmit(flags uint32) bool { + if flags&(EvtFlagRebased|EvtFlagTakedown) != 0 { + return true + } + + return false +} + +func (dp *DiskPersistence) readEventsFrom(ctx context.Context, since int64, fn string, cb func(*XRPCStreamEvent) error) (*int64, error) { + fi, err := os.OpenFile(fn, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + + if since != 0 { + lastSeq, err := scanForLastSeq(fi, since) + if err != nil { + return nil, err + } + if since > lastSeq { + log.Error("playback cursor is greater than last seq of file checked", + "since", since, + "lastSeq", lastSeq, + "filename", fn, + ) + return nil, nil + } + } + + bufr := bufio.NewReader(fi) + + lastSeq := int64(0) + + scratch := make([]byte, headerSize) + for { + h, err := readHeader(bufr, scratch) + if err != nil { + if errors.Is(err, io.EOF) { + return &lastSeq, nil + } + + return nil, err + } + + lastSeq = h.Seq + + if postDoNotEmit(h.Flags) { + // event taken down, skip + _, err := io.CopyN(io.Discard, bufr, h.Len64()) // would be really nice if the buffered reader had a 'skip' method that does a seek under the hood + if err != nil { + return nil, fmt.Errorf("failed while skipping event (seq: %d, fn: %q): %w", h.Seq, fn, err) + } + continue + } + + switch h.Kind { + case evtKindCommit: + var evt atproto.SyncSubscribeRepos_Commit + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { + return nil, err + } + evt.Seq = h.Seq + if err := cb(&XRPCStreamEvent{RepoCommit: &evt}); err != nil { + return nil, err + } + case evtKindHandle: + var evt atproto.SyncSubscribeRepos_Handle + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { + return nil, err + } + evt.Seq = h.Seq + if err := cb(&XRPCStreamEvent{RepoHandle: &evt}); err != nil { + return nil, err + } + case evtKindIdentity: + var evt atproto.SyncSubscribeRepos_Identity + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { + return nil, err + } + evt.Seq = h.Seq + if err := cb(&XRPCStreamEvent{RepoIdentity: &evt}); err != nil { + return nil, err + } + case evtKindAccount: + var evt atproto.SyncSubscribeRepos_Account + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { + return nil, err + } + evt.Seq = h.Seq + if err := cb(&XRPCStreamEvent{RepoAccount: &evt}); err != nil { + return nil, err + } + case evtKindTombstone: + var evt atproto.SyncSubscribeRepos_Tombstone + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { + return nil, err + } + evt.Seq = h.Seq + if err := 
cb(&XRPCStreamEvent{RepoTombstone: &evt}); err != nil { + return nil, err + } + default: + log.Warn("unrecognized event kind coming from log file", "seq", h.Seq, "kind", h.Kind) + return nil, fmt.Errorf("halting on unrecognized event kind") + } + } +} + +type UserAction struct { + gorm.Model + + Usr models.Uid + RebaseAt int64 + Takedown bool +} + +func (dp *DiskPersistence) TakeDownRepo(ctx context.Context, usr models.Uid) error { + /* + if err := p.meta.Create(&UserAction{ + Usr: usr, + Takedown: true, + }).Error; err != nil { + return err + } + */ + + return dp.forEachShardWithUserEvents(ctx, usr, func(ctx context.Context, fn string) error { + if err := dp.deleteEventsForUser(ctx, usr, fn); err != nil { + return err + } + + return nil + }) +} + +func (dp *DiskPersistence) forEachShardWithUserEvents(ctx context.Context, usr models.Uid, cb func(context.Context, string) error) error { + var refs []LogFileRef + if err := dp.meta.Order("created_at desc").Find(&refs).Error; err != nil { + return err + } + + for _, r := range refs { + mhas, err := dp.refMaybeHasUserEvents(ctx, usr, r) + if err != nil { + return err + } + + if mhas { + var path string + if r.Archived { + path = filepath.Join(dp.archiveDir, r.Path) + } else { + path = filepath.Join(dp.primaryDir, r.Path) + } + + if err := cb(ctx, path); err != nil { + return err + } + } + } + + return nil +} + +func (dp *DiskPersistence) refMaybeHasUserEvents(ctx context.Context, usr models.Uid, ref LogFileRef) (bool, error) { + // TODO: lazily computed bloom filters for users in each logfile + return true, nil +} + +type zeroReader struct{} + +func (zr *zeroReader) Read(p []byte) (n int, err error) { + for i := range p { + p[i] = 0 + } + return len(p), nil +} + +func (dp *DiskPersistence) deleteEventsForUser(ctx context.Context, usr models.Uid, fn string) error { + return dp.mutateUserEventsInLog(ctx, usr, fn, EvtFlagTakedown, true) +} + +func (dp *DiskPersistence) mutateUserEventsInLog(ctx context.Context, usr models.Uid, fn string, flag uint32, zeroEvts bool) error { + fi, err := os.OpenFile(fn, os.O_RDWR, 0) + if err != nil { + return fmt.Errorf("failed to open log file: %w", err) + } + defer fi.Close() + defer fi.Sync() + + scratch := make([]byte, headerSize) + var offset int64 + for { + h, err := readHeader(fi, scratch) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + + return err + } + + if h.Usr == usr && h.Flags&flag == 0 { + nflag := h.Flags | flag + + binary.LittleEndian.PutUint32(scratch, nflag) + + if _, err := fi.WriteAt(scratch[:4], offset); err != nil { + return fmt.Errorf("failed to write updated flag value: %w", err) + } + + if zeroEvts { + // sync that write before blanking the event data + if err := fi.Sync(); err != nil { + return err + } + + if _, err := fi.Seek(offset+headerSize, io.SeekStart); err != nil { + return fmt.Errorf("failed to seek: %w", err) + } + + _, err := io.CopyN(fi, &zeroReader{}, h.Len64()) + if err != nil { + return err + } + } + } + + offset += headerSize + h.Len64() + _, err = fi.Seek(offset, io.SeekStart) + if err != nil { + return fmt.Errorf("failed to seek: %w", err) + } + } +} + +func (dp *DiskPersistence) Flush(ctx context.Context) error { + dp.lk.Lock() + defer dp.lk.Unlock() + if len(dp.evtbuf) > 0 { + return dp.flushLog(ctx) + } + return nil +} + +func (dp *DiskPersistence) Shutdown(ctx context.Context) error { + close(dp.shutdown) + if err := dp.Flush(ctx); err != nil { + return err + } + + dp.logfi.Close() + return nil +} + +func (dp *DiskPersistence) SetEventBroadcaster(f 
func(*XRPCStreamEvent)) { + dp.broadcast = f +} diff --git a/cmd/relay/events/events.go b/cmd/relay/events/events.go new file mode 100644 index 000000000..08b8f18a1 --- /dev/null +++ b/cmd/relay/events/events.go @@ -0,0 +1,515 @@ +package events + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "log/slog" + "sync" + "time" + + comatproto "github.com/bluesky-social/indigo/api/atproto" + "github.com/bluesky-social/indigo/cmd/relay/models" + lexutil "github.com/bluesky-social/indigo/lex/util" + "github.com/prometheus/client_golang/prometheus" + + cbg "github.com/whyrusleeping/cbor-gen" + "go.opentelemetry.io/otel" +) + +var log = slog.Default().With("system", "events") + +type Scheduler interface { + AddWork(ctx context.Context, repo string, val *XRPCStreamEvent) error + Shutdown() +} + +type EventManager struct { + subs []*Subscriber + subsLk sync.Mutex + + bufferSize int + crossoverBufferSize int + + persister EventPersistence + + log *slog.Logger +} + +func NewEventManager(persister EventPersistence) *EventManager { + em := &EventManager{ + bufferSize: 16 << 10, + crossoverBufferSize: 512, + persister: persister, + log: slog.Default().With("system", "events"), + } + + persister.SetEventBroadcaster(em.broadcastEvent) + + return em +} + +const ( + opSubscribe = iota + opUnsubscribe + opSend +) + +type Operation struct { + op int + sub *Subscriber + evt *XRPCStreamEvent +} + +func (em *EventManager) Shutdown(ctx context.Context) error { + return em.persister.Shutdown(ctx) +} + +func (em *EventManager) broadcastEvent(evt *XRPCStreamEvent) { + // the main thing we do is send it out, so MarshalCBOR once + if err := evt.Preserialize(); err != nil { + em.log.Error("broadcast serialize failed", "err", err) + // serialize isn't going to go better later, this event is cursed + return + } + + em.subsLk.Lock() + defer em.subsLk.Unlock() + + // TODO: for a larger fanout we should probably have dedicated goroutines + // for subsets of the subscriber set, and tiered channels to distribute + // events out to them, or some similar architecture + // Alternatively, we might just want to not allow too many subscribers + // directly to the bgs, and have rebroadcasting proxies instead + for _, s := range em.subs { + if s.filter(evt) { + s.enqueuedCounter.Inc() + select { + case s.outgoing <- evt: + case <-s.done: + default: + // filter out all future messages that would be + // sent to this subscriber, but wait for it to + // actually be removed by the correct bit of + // code + s.filter = func(*XRPCStreamEvent) bool { return false } + + em.log.Warn("dropping slow consumer due to event overflow", "bufferSize", len(s.outgoing), "ident", s.ident) + go func(torem *Subscriber) { + torem.lk.Lock() + if !torem.cleanedUp { + select { + case torem.outgoing <- &XRPCStreamEvent{ + Error: &ErrorFrame{ + Error: "ConsumerTooSlow", + }, + }: + case <-time.After(time.Second * 5): + em.log.Warn("failed to send error frame to backed up consumer", "ident", torem.ident) + } + } + torem.lk.Unlock() + torem.cleanup() + }(s) + } + s.broadcastCounter.Inc() + } + } +} + +func (em *EventManager) persistAndSendEvent(ctx context.Context, evt *XRPCStreamEvent) { + // TODO: can cut 5-10% off of disk persister benchmarks by making this function + // accept a uid. The lookup inside the persister is notably expensive (despite + // being an lru cache?) 
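+ // the persister assigns the sequence number and then invokes the broadcast callback registered via SetEventBroadcaster, fanning the event out to live subscribers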
+ if err := em.persister.Persist(ctx, evt); err != nil { + em.log.Error("failed to persist outbound event", "err", err) + } +} + +type Subscriber struct { + outgoing chan *XRPCStreamEvent + + filter func(*XRPCStreamEvent) bool + + done chan struct{} + + cleanup func() + + lk sync.Mutex + cleanedUp bool + + ident string + enqueuedCounter prometheus.Counter + broadcastCounter prometheus.Counter +} + +const ( + EvtKindErrorFrame = -1 + EvtKindMessage = 1 +) + +type EventHeader struct { + Op int64 `cborgen:"op"` + MsgType string `cborgen:"t"` +} + +var ( + AccountStatusActive = "active" + AccountStatusTakendown = "takendown" + AccountStatusSuspended = "suspended" + AccountStatusDeleted = "deleted" + AccountStatusDeactivated = "deactivated" +) + +type XRPCStreamEvent struct { + Error *ErrorFrame + RepoCommit *comatproto.SyncSubscribeRepos_Commit + RepoSync *comatproto.SyncSubscribeRepos_Sync + RepoHandle *comatproto.SyncSubscribeRepos_Handle + RepoIdentity *comatproto.SyncSubscribeRepos_Identity + RepoInfo *comatproto.SyncSubscribeRepos_Info + RepoMigrate *comatproto.SyncSubscribeRepos_Migrate + RepoTombstone *comatproto.SyncSubscribeRepos_Tombstone + RepoAccount *comatproto.SyncSubscribeRepos_Account + LabelLabels *comatproto.LabelSubscribeLabels_Labels + LabelInfo *comatproto.LabelSubscribeLabels_Info + + // some private fields for internal routing perf + PrivUid models.Uid `json:"-" cborgen:"-"` + PrivPdsId uint `json:"-" cborgen:"-"` + PrivRelevantPds []uint `json:"-" cborgen:"-"` + Preserialized []byte `json:"-" cborgen:"-"` +} + +func (evt *XRPCStreamEvent) Serialize(wc io.Writer) error { + header := EventHeader{Op: EvtKindMessage} + var obj lexutil.CBOR + + switch { + case evt.Error != nil: + header.Op = EvtKindErrorFrame + obj = evt.Error + case evt.RepoCommit != nil: + header.MsgType = "#commit" + obj = evt.RepoCommit + case evt.RepoSync != nil: + header.MsgType = "#sync" + obj = evt.RepoSync + case evt.RepoHandle != nil: + header.MsgType = "#handle" + obj = evt.RepoHandle + case evt.RepoIdentity != nil: + header.MsgType = "#identity" + obj = evt.RepoIdentity + case evt.RepoAccount != nil: + header.MsgType = "#account" + obj = evt.RepoAccount + case evt.RepoInfo != nil: + header.MsgType = "#info" + obj = evt.RepoInfo + case evt.RepoMigrate != nil: + header.MsgType = "#migrate" + obj = evt.RepoMigrate + case evt.RepoTombstone != nil: + header.MsgType = "#tombstone" + obj = evt.RepoTombstone + default: + return fmt.Errorf("unrecognized event kind") + } + + cborWriter := cbg.NewCborWriter(wc) + if err := header.MarshalCBOR(cborWriter); err != nil { + return fmt.Errorf("failed to write header: %w", err) + } + return obj.MarshalCBOR(cborWriter) +} + +func (xevt *XRPCStreamEvent) Deserialize(r io.Reader) error { + var header EventHeader + if err := header.UnmarshalCBOR(r); err != nil { + return fmt.Errorf("reading header: %w", err) + } + switch header.Op { + case EvtKindMessage: + switch header.MsgType { + case "#commit": + var evt comatproto.SyncSubscribeRepos_Commit + if err := evt.UnmarshalCBOR(r); err != nil { + return fmt.Errorf("reading repoCommit event: %w", err) + } + xevt.RepoCommit = &evt + case "#sync": + var evt comatproto.SyncSubscribeRepos_Sync + if err := evt.UnmarshalCBOR(r); err != nil { + return fmt.Errorf("reading repoCommit event: %w", err) + } + xevt.RepoSync = &evt + case "#handle": + var evt comatproto.SyncSubscribeRepos_Handle + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + xevt.RepoHandle = &evt + case "#identity": + var evt 
comatproto.SyncSubscribeRepos_Identity + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + xevt.RepoIdentity = &evt + case "#account": + var evt comatproto.SyncSubscribeRepos_Account + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + xevt.RepoAccount = &evt + case "#info": + // TODO: this might also be a LabelInfo (as opposed to RepoInfo) + var evt comatproto.SyncSubscribeRepos_Info + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + xevt.RepoInfo = &evt + case "#migrate": + var evt comatproto.SyncSubscribeRepos_Migrate + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + xevt.RepoMigrate = &evt + case "#tombstone": + var evt comatproto.SyncSubscribeRepos_Tombstone + if err := evt.UnmarshalCBOR(r); err != nil { + return err + } + xevt.RepoTombstone = &evt + case "#labels": + var evt comatproto.LabelSubscribeLabels_Labels + if err := evt.UnmarshalCBOR(r); err != nil { + return fmt.Errorf("reading Labels event: %w", err) + } + xevt.LabelLabels = &evt + } + case EvtKindErrorFrame: + var errframe ErrorFrame + if err := errframe.UnmarshalCBOR(r); err != nil { + return err + } + xevt.Error = &errframe + default: + return fmt.Errorf("unrecognized event stream type: %d", header.Op) + } + return nil +} + +var ErrNoSeq = errors.New("event has no sequence number") + +// serialize content into Preserialized cache +func (evt *XRPCStreamEvent) Preserialize() error { + if evt.Preserialized != nil { + return nil + } + var buf bytes.Buffer + err := evt.Serialize(&buf) + if err != nil { + return err + } + evt.Preserialized = buf.Bytes() + return nil +} + +type ErrorFrame struct { + Error string `cborgen:"error"` + Message string `cborgen:"message"` +} + +func (em *EventManager) AddEvent(ctx context.Context, ev *XRPCStreamEvent) error { + ctx, span := otel.Tracer("events").Start(ctx, "AddEvent") + defer span.End() + + em.persistAndSendEvent(ctx, ev) + return nil +} + +var ( + ErrPlaybackShutdown = fmt.Errorf("playback shutting down") + ErrCaughtUp = fmt.Errorf("caught up") +) + +func (em *EventManager) Subscribe(ctx context.Context, ident string, filter func(*XRPCStreamEvent) bool, since *int64) (<-chan *XRPCStreamEvent, func(), error) { + if filter == nil { + filter = func(*XRPCStreamEvent) bool { return true } + } + + done := make(chan struct{}) + sub := &Subscriber{ + ident: ident, + outgoing: make(chan *XRPCStreamEvent, em.bufferSize), + filter: filter, + done: done, + enqueuedCounter: eventsEnqueued.WithLabelValues(ident), + broadcastCounter: eventsBroadcast.WithLabelValues(ident), + } + + sub.cleanup = sync.OnceFunc(func() { + sub.lk.Lock() + defer sub.lk.Unlock() + close(done) + em.rmSubscriber(sub) + close(sub.outgoing) + sub.cleanedUp = true + }) + + if since == nil { + em.addSubscriber(sub) + return sub.outgoing, sub.cleanup, nil + } + + out := make(chan *XRPCStreamEvent, em.crossoverBufferSize) + + go func() { + lastSeq := *since + // run playback to get through *most* of the events, getting our current cursor close to realtime + if err := em.persister.Playback(ctx, *since, func(e *XRPCStreamEvent) error { + select { + case <-done: + return ErrPlaybackShutdown + case out <- e: + seq := SequenceForEvent(e) + if seq > 0 { + lastSeq = seq + } + return nil + } + }); err != nil { + if errors.Is(err, ErrPlaybackShutdown) { + em.log.Warn("events playback", "err", err) + } else { + em.log.Error("events playback", "err", err) + } + + // TODO: send an error frame or something? 
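+ // on this path the subscriber was never registered via addSubscriber, so closing the crossover channel is the only cleanup needed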
+ close(out) + return + } + + // now, start buffering events from the live stream + em.addSubscriber(sub) + + first := <-sub.outgoing + + // run playback again to get us to the events that have started buffering + if err := em.persister.Playback(ctx, lastSeq, func(e *XRPCStreamEvent) error { + seq := SequenceForEvent(e) + if seq > SequenceForEvent(first) { + return ErrCaughtUp + } + + select { + case <-done: + return ErrPlaybackShutdown + case out <- e: + return nil + } + }); err != nil { + if !errors.Is(err, ErrCaughtUp) { + em.log.Error("events playback", "err", err) + + // TODO: send an error frame or something? + close(out) + em.rmSubscriber(sub) + return + } + } + + // now that we are caught up, just copy events from the channel over + for evt := range sub.outgoing { + select { + case out <- evt: + case <-done: + em.rmSubscriber(sub) + return + } + } + }() + + return out, sub.cleanup, nil +} + +func SequenceForEvent(evt *XRPCStreamEvent) int64 { + return evt.Sequence() +} + +func (evt *XRPCStreamEvent) Sequence() int64 { + switch { + case evt == nil: + return -1 + case evt.RepoCommit != nil: + return evt.RepoCommit.Seq + case evt.RepoHandle != nil: + return evt.RepoHandle.Seq + case evt.RepoMigrate != nil: + return evt.RepoMigrate.Seq + case evt.RepoTombstone != nil: + return evt.RepoTombstone.Seq + case evt.RepoIdentity != nil: + return evt.RepoIdentity.Seq + case evt.RepoAccount != nil: + return evt.RepoAccount.Seq + case evt.RepoInfo != nil: + return -1 + case evt.Error != nil: + return -1 + default: + return -1 + } +} + +func (evt *XRPCStreamEvent) GetSequence() (int64, bool) { + switch { + case evt == nil: + return -1, false + case evt.RepoCommit != nil: + return evt.RepoCommit.Seq, true + case evt.RepoHandle != nil: + return evt.RepoHandle.Seq, true + case evt.RepoMigrate != nil: + return evt.RepoMigrate.Seq, true + case evt.RepoTombstone != nil: + return evt.RepoTombstone.Seq, true + case evt.RepoIdentity != nil: + return evt.RepoIdentity.Seq, true + case evt.RepoAccount != nil: + return evt.RepoAccount.Seq, true + case evt.RepoInfo != nil: + return -1, false + case evt.Error != nil: + return -1, false + default: + return -1, false + } +} + +func (em *EventManager) rmSubscriber(sub *Subscriber) { + em.subsLk.Lock() + defer em.subsLk.Unlock() + + for i, s := range em.subs { + if s == sub { + em.subs[i] = em.subs[len(em.subs)-1] + em.subs = em.subs[:len(em.subs)-1] + break + } + } +} + +func (em *EventManager) addSubscriber(sub *Subscriber) { + em.subsLk.Lock() + defer em.subsLk.Unlock() + + em.subs = append(em.subs, sub) +} + +func (em *EventManager) TakeDownRepo(ctx context.Context, user models.Uid) error { + return em.persister.TakeDownRepo(ctx, user) +} diff --git a/cmd/relay/events/metrics.go b/cmd/relay/events/metrics.go new file mode 100644 index 000000000..0788f2f76 --- /dev/null +++ b/cmd/relay/events/metrics.go @@ -0,0 +1,26 @@ +package events + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var eventsFromStreamCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "indigo_repo_stream_events_received_total", + Help: "Total number of events received from the stream", +}, []string{"remote_addr"}) + +var bytesFromStreamCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "indigo_repo_stream_bytes_total", + Help: "Total bytes received from the stream", +}, []string{"remote_addr"}) + +var eventsEnqueued = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: 
"indigo_events_enqueued_for_broadcast_total", + Help: "Total number of events enqueued to broadcast to subscribers", +}, []string{"pool"}) + +var eventsBroadcast = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "indigo_events_broadcast_total", + Help: "Total number of events broadcast to subscribers", +}, []string{"pool"}) diff --git a/cmd/relay/events/pebblepersist.go b/cmd/relay/events/pebblepersist.go new file mode 100644 index 000000000..a4b1039fe --- /dev/null +++ b/cmd/relay/events/pebblepersist.go @@ -0,0 +1,262 @@ +package events + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/bluesky-social/indigo/cmd/relay/models" + "github.com/cockroachdb/pebble" +) + +type PebblePersist struct { + broadcast func(*XRPCStreamEvent) + db *pebble.DB + + prevSeq int64 + prevSeqExtra uint32 + + cancel func() + + options PebblePersistOptions +} + +type PebblePersistOptions struct { + // path where pebble will create a directory full of files + DbPath string + + // Throw away posts older than some time ago + PersistDuration time.Duration + + // Throw away old posts every so often + GCPeriod time.Duration + + // MaxBytes is what we _try_ to keep disk usage under + MaxBytes uint64 +} + +var DefaultPebblePersistOptions = PebblePersistOptions{ + PersistDuration: time.Minute * 20, + GCPeriod: time.Minute * 5, + MaxBytes: 1024 * 1024 * 1024, // 1 GiB +} + +// Create a new EventPersistence which stores data in pebbledb +// nil opts is ok +func NewPebblePersistance(opts *PebblePersistOptions) (*PebblePersist, error) { + if opts == nil { + opts = &DefaultPebblePersistOptions + } + db, err := pebble.Open(opts.DbPath, &pebble.Options{}) + if err != nil { + return nil, fmt.Errorf("%s: %w", opts.DbPath, err) + } + pp := new(PebblePersist) + pp.options = *opts + pp.db = db + return pp, nil +} + +func setKeySeqMillis(key []byte, seq, millis int64) { + binary.BigEndian.PutUint64(key[:8], uint64(seq)) + binary.BigEndian.PutUint64(key[8:16], uint64(millis)) +} + +func (pp *PebblePersist) Persist(ctx context.Context, e *XRPCStreamEvent) error { + err := e.Preserialize() + if err != nil { + return err + } + blob := e.Preserialized + + seq := e.Sequence() + nowMillis := time.Now().UnixMilli() + + if seq < 0 { + // persist with longer key {prev 8 byte key}{time}{int32 extra counter} + pp.prevSeqExtra++ + var key [20]byte + setKeySeqMillis(key[:], seq, nowMillis) + binary.BigEndian.PutUint32(key[16:], pp.prevSeqExtra) + + err = pp.db.Set(key[:], blob, pebble.Sync) + } else { + pp.prevSeq = seq + pp.prevSeqExtra = 0 + var key [16]byte + setKeySeqMillis(key[:], seq, nowMillis) + + err = pp.db.Set(key[:], blob, pebble.Sync) + } + + if err != nil { + return err + } + pp.broadcast(e) + + return err +} + +func eventFromPebbleIter(iter *pebble.Iterator) (*XRPCStreamEvent, error) { + blob, err := iter.ValueAndErr() + if err != nil { + return nil, err + } + br := bytes.NewReader(blob) + evt := new(XRPCStreamEvent) + err = evt.Deserialize(br) + if err != nil { + return nil, err + } + evt.Preserialized = bytes.Clone(blob) + return evt, nil +} + +func (pp *PebblePersist) Playback(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error) error { + var key [8]byte + binary.BigEndian.PutUint64(key[:], uint64(since)) + + iter, err := pp.db.NewIterWithContext(ctx, &pebble.IterOptions{LowerBound: key[:]}) + if err != nil { + return err + } + defer iter.Close() + + for iter.First(); iter.Valid(); iter.Next() { + evt, err := eventFromPebbleIter(iter) + if 
err != nil { + return err + } + + err = cb(evt) + if err != nil { + return err + } + } + + return nil +} +func (pp *PebblePersist) TakeDownRepo(ctx context.Context, usr models.Uid) error { + // TODO: implement filter on playback to ignore taken-down-repos? + return nil +} +func (pp *PebblePersist) Flush(context.Context) error { + return pp.db.Flush() +} +func (pp *PebblePersist) Shutdown(context.Context) error { + if pp.cancel != nil { + pp.cancel() + } + err := pp.db.Close() + pp.db = nil + return err +} + +func (pp *PebblePersist) SetEventBroadcaster(broadcast func(*XRPCStreamEvent)) { + pp.broadcast = broadcast +} + +var ErrNoLast = errors.New("no last event") + +func (pp *PebblePersist) GetLast(ctx context.Context) (seq, millis int64, evt *XRPCStreamEvent, err error) { + iter, err := pp.db.NewIterWithContext(ctx, &pebble.IterOptions{}) + if err != nil { + return 0, 0, nil, err + } + ok := iter.Last() + if !ok { + return 0, 0, nil, ErrNoLast + } + evt, err = eventFromPebbleIter(iter) + keyblob := iter.Key() + seq = int64(binary.BigEndian.Uint64(keyblob[:8])) + millis = int64(binary.BigEndian.Uint64(keyblob[8:16])) + return seq, millis, evt, nil +} + +// example; +// ``` +// pp := NewPebblePersistance("/tmp/foo.pebble") +// go pp.GCThread(context.Background(), 48 * time.Hour, 5 * time.Minute) +// ``` +func (pp *PebblePersist) GCThread(ctx context.Context) { + ctx, cancel := context.WithCancel(ctx) + pp.cancel = cancel + ticker := time.NewTicker(pp.options.GCPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + err := pp.GarbageCollect(ctx) + if err != nil { + log.Error("GC err", "err", err) + } + case <-ctx.Done(): + return + } + } +} + +var zeroKey [16]byte +var ffffKey [16]byte + +func init() { + setKeySeqMillis(zeroKey[:], 0, 0) + for i := range ffffKey { + ffffKey[i] = 0xff + } +} + +func (pp *PebblePersist) GarbageCollect(ctx context.Context) error { + nowMillis := time.Now().UnixMilli() + expired := nowMillis - pp.options.PersistDuration.Milliseconds() + iter, err := pp.db.NewIterWithContext(ctx, &pebble.IterOptions{}) + if err != nil { + return err + } + defer iter.Close() + // scan keys to find last expired, then delete range + var seq int64 = int64(-1) + var lastKeyTime int64 + for iter.First(); iter.Valid(); iter.Next() { + keyblob := iter.Key() + + keyTime := int64(binary.BigEndian.Uint64(keyblob[8:16])) + if keyTime <= expired { + lastKeyTime = keyTime + seq = int64(binary.BigEndian.Uint64(keyblob[:8])) + } else { + break + } + } + + // TODO: use pp.options.MaxBytes + + sizeBefore, _ := pp.db.EstimateDiskUsage(zeroKey[:], ffffKey[:]) + if seq == -1 { + // nothing to delete + log.Info("pebble gc nop", "size", sizeBefore) + return nil + } + var key [16]byte + setKeySeqMillis(key[:], seq, lastKeyTime) + log.Info("pebble gc start", "to", hex.EncodeToString(key[:])) + err = pp.db.DeleteRange(zeroKey[:], key[:], pebble.Sync) + if err != nil { + return err + } + sizeAfter, _ := pp.db.EstimateDiskUsage(zeroKey[:], ffffKey[:]) + log.Info("pebble gc", "before", sizeBefore, "after", sizeAfter) + start := time.Now() + err = pp.db.Compact(zeroKey[:], key[:], true) + if err != nil { + log.Warn("pebble gc compact", "err", err) + } + dt := time.Since(start) + log.Info("pebble gc compact ok", "dt", dt) + return nil +} diff --git a/cmd/relay/events/persist.go b/cmd/relay/events/persist.go new file mode 100644 index 000000000..82d57f8fe --- /dev/null +++ b/cmd/relay/events/persist.go @@ -0,0 +1,99 @@ +package events + +import ( + "context" + "fmt" + "sync" + + 
"github.com/bluesky-social/indigo/cmd/relay/models" +) + +// Note that this interface looks generic, but some persisters might only work with RepoAppend or LabelLabels +type EventPersistence interface { + Persist(ctx context.Context, e *XRPCStreamEvent) error + Playback(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error) error + TakeDownRepo(ctx context.Context, usr models.Uid) error + Flush(context.Context) error + Shutdown(context.Context) error + + SetEventBroadcaster(func(*XRPCStreamEvent)) +} + +// MemPersister is the most naive implementation of event persistence +// This EventPersistence option works fine with all event types +// ill do better later +type MemPersister struct { + buf []*XRPCStreamEvent + lk sync.Mutex + seq int64 + + broadcast func(*XRPCStreamEvent) +} + +func NewMemPersister() *MemPersister { + return &MemPersister{} +} + +func (mp *MemPersister) Persist(ctx context.Context, e *XRPCStreamEvent) error { + mp.lk.Lock() + defer mp.lk.Unlock() + mp.seq++ + switch { + case e.RepoCommit != nil: + e.RepoCommit.Seq = mp.seq + case e.RepoHandle != nil: + e.RepoHandle.Seq = mp.seq + case e.RepoIdentity != nil: + e.RepoIdentity.Seq = mp.seq + case e.RepoAccount != nil: + e.RepoAccount.Seq = mp.seq + case e.RepoMigrate != nil: + e.RepoMigrate.Seq = mp.seq + case e.RepoTombstone != nil: + e.RepoTombstone.Seq = mp.seq + case e.LabelLabels != nil: + e.LabelLabels.Seq = mp.seq + default: + panic("no event in persist call") + } + mp.buf = append(mp.buf, e) + + mp.broadcast(e) + + return nil +} + +func (mp *MemPersister) Playback(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error) error { + mp.lk.Lock() + l := len(mp.buf) + mp.lk.Unlock() + + if since >= int64(l) { + return nil + } + + // TODO: abusing the fact that buf[0].seq is currently always 1 + for _, e := range mp.buf[since:l] { + if err := cb(e); err != nil { + return err + } + } + + return nil +} + +func (mp *MemPersister) TakeDownRepo(ctx context.Context, uid models.Uid) error { + return fmt.Errorf("repo takedowns not currently supported by memory persister, test usage only") +} + +func (mp *MemPersister) Flush(ctx context.Context) error { + return nil +} + +func (mp *MemPersister) SetEventBroadcaster(brc func(*XRPCStreamEvent)) { + mp.broadcast = brc +} + +func (mp *MemPersister) Shutdown(context.Context) error { + return nil +} diff --git a/cmd/relay/events/schedulers/metrics.go b/cmd/relay/events/schedulers/metrics.go new file mode 100644 index 000000000..4b3940cac --- /dev/null +++ b/cmd/relay/events/schedulers/metrics.go @@ -0,0 +1,26 @@ +package schedulers + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var WorkItemsAdded = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "indigo_scheduler_work_items_added_total", + Help: "Total number of work items added to the consumer pool", +}, []string{"pool", "scheduler_type"}) + +var WorkItemsProcessed = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "indigo_scheduler_work_items_processed_total", + Help: "Total number of work items processed by the consumer pool", +}, []string{"pool", "scheduler_type"}) + +var WorkItemsActive = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "indigo_scheduler_work_items_active_total", + Help: "Total number of work items passed into a worker", +}, []string{"pool", "scheduler_type"}) + +var WorkersActive = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Name: "indigo_scheduler_workers_active", + Help: "Number 
of workers currently active", +}, []string{"pool", "scheduler_type"}) diff --git a/cmd/relay/events/schedulers/parallel/parallel.go b/cmd/relay/events/schedulers/parallel/parallel.go new file mode 100644 index 000000000..649aed2d6 --- /dev/null +++ b/cmd/relay/events/schedulers/parallel/parallel.go @@ -0,0 +1,148 @@ +package parallel + +import ( + "context" + "log/slog" + "sync" + + "github.com/bluesky-social/indigo/cmd/relay/events" + "github.com/bluesky-social/indigo/events/schedulers" + + "github.com/prometheus/client_golang/prometheus" +) + +// Scheduler is a parallel scheduler that will run work on a fixed number of workers +type Scheduler struct { + maxConcurrency int + maxQueue int + + do func(context.Context, *events.XRPCStreamEvent) error + + feeder chan *consumerTask + out chan struct{} + + lk sync.Mutex + active map[string][]*consumerTask + + ident string + + // metrics + itemsAdded prometheus.Counter + itemsProcessed prometheus.Counter + itemsActive prometheus.Counter + workesActive prometheus.Gauge + + log *slog.Logger +} + +func NewScheduler(maxC, maxQ int, ident string, do func(context.Context, *events.XRPCStreamEvent) error) *Scheduler { + p := &Scheduler{ + maxConcurrency: maxC, + maxQueue: maxQ, + + do: do, + + feeder: make(chan *consumerTask), + active: make(map[string][]*consumerTask), + out: make(chan struct{}), + + ident: ident, + + itemsAdded: schedulers.WorkItemsAdded.WithLabelValues(ident, "parallel"), + itemsProcessed: schedulers.WorkItemsProcessed.WithLabelValues(ident, "parallel"), + itemsActive: schedulers.WorkItemsActive.WithLabelValues(ident, "parallel"), + workesActive: schedulers.WorkersActive.WithLabelValues(ident, "parallel"), + + log: slog.Default().With("system", "parallel-scheduler"), + } + + for i := 0; i < maxC; i++ { + go p.worker() + } + + p.workesActive.Set(float64(maxC)) + + return p +} + +func (p *Scheduler) Shutdown() { + p.log.Info("shutting down parallel scheduler", "ident", p.ident) + + for i := 0; i < p.maxConcurrency; i++ { + p.feeder <- &consumerTask{ + control: "stop", + } + } + + close(p.feeder) + + for i := 0; i < p.maxConcurrency; i++ { + <-p.out + } + + p.log.Info("parallel scheduler shutdown complete") +} + +type consumerTask struct { + repo string + val *events.XRPCStreamEvent + control string +} + +func (p *Scheduler) AddWork(ctx context.Context, repo string, val *events.XRPCStreamEvent) error { + p.itemsAdded.Inc() + t := &consumerTask{ + repo: repo, + val: val, + } + p.lk.Lock() + + a, ok := p.active[repo] + if ok { + p.active[repo] = append(a, t) + p.lk.Unlock() + return nil + } + + p.active[repo] = []*consumerTask{} + p.lk.Unlock() + + select { + case p.feeder <- t: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (p *Scheduler) worker() { + for work := range p.feeder { + for work != nil { + if work.control == "stop" { + p.out <- struct{}{} + return + } + + p.itemsActive.Inc() + if err := p.do(context.TODO(), work.val); err != nil { + p.log.Error("event handler failed", "err", err) + } + p.itemsProcessed.Inc() + + p.lk.Lock() + rem, ok := p.active[work.repo] + if !ok { + p.log.Error("should always have an 'active' entry if a worker is processing a job") + } + + if len(rem) == 0 { + delete(p.active, work.repo) + work = nil + } else { + work = rem[0] + p.active[work.repo] = rem[1:] + } + p.lk.Unlock() + } + } +} diff --git a/cmd/relay/events/schedulers/scheduler.go b/cmd/relay/events/schedulers/scheduler.go new file mode 100644 index 000000000..9185832f5 --- /dev/null +++ 
b/cmd/relay/events/schedulers/scheduler.go @@ -0,0 +1 @@ +package schedulers diff --git a/cmd/relay/events/yolopersist.go b/cmd/relay/events/yolopersist.go new file mode 100644 index 000000000..fe4310ef4 --- /dev/null +++ b/cmd/relay/events/yolopersist.go @@ -0,0 +1,69 @@ +package events + +import ( + "context" + "fmt" + "sync" + + "github.com/bluesky-social/indigo/cmd/relay/models" +) + +// YoloPersister is used for benchmarking, it has no persistence, it just emits events and forgets them +type YoloPersister struct { + lk sync.Mutex + seq int64 + + broadcast func(*XRPCStreamEvent) +} + +func NewYoloPersister() *YoloPersister { + return &YoloPersister{} +} + +func (yp *YoloPersister) Persist(ctx context.Context, e *XRPCStreamEvent) error { + yp.lk.Lock() + defer yp.lk.Unlock() + yp.seq++ + switch { + case e.RepoCommit != nil: + e.RepoCommit.Seq = yp.seq + case e.RepoHandle != nil: + e.RepoHandle.Seq = yp.seq + case e.RepoIdentity != nil: + e.RepoIdentity.Seq = yp.seq + case e.RepoAccount != nil: + e.RepoAccount.Seq = yp.seq + case e.RepoMigrate != nil: + e.RepoMigrate.Seq = yp.seq + case e.RepoTombstone != nil: + e.RepoTombstone.Seq = yp.seq + case e.LabelLabels != nil: + e.LabelLabels.Seq = yp.seq + default: + panic("no event in persist call") + } + + yp.broadcast(e) + + return nil +} + +func (mp *YoloPersister) Playback(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error) error { + return fmt.Errorf("playback not supported by yolo persister, test usage only") +} + +func (yp *YoloPersister) TakeDownRepo(ctx context.Context, uid models.Uid) error { + return fmt.Errorf("repo takedowns not currently supported by memory persister, test usage only") +} + +func (yp *YoloPersister) SetEventBroadcaster(brc func(*XRPCStreamEvent)) { + yp.broadcast = brc +} + +func (yp *YoloPersister) Flush(ctx context.Context) error { + return nil +} + +func (yp *YoloPersister) Shutdown(ctx context.Context) error { + return nil +} diff --git a/cmd/relay/main.go b/cmd/relay/main.go new file mode 100644 index 000000000..bdd27c669 --- /dev/null +++ b/cmd/relay/main.go @@ -0,0 +1,494 @@ +package main + +import ( + "context" + "errors" + "fmt" + "github.com/bluesky-social/indigo/atproto/identity" + "gorm.io/gorm" + "io" + "log/slog" + _ "net/http/pprof" + "net/url" + "os" + "os/signal" + "path/filepath" + "strconv" + "strings" + "syscall" + "time" + + libbgs "github.com/bluesky-social/indigo/cmd/relay/bgs" + "github.com/bluesky-social/indigo/cmd/relay/events" + "github.com/bluesky-social/indigo/cmd/relay/repomgr" + "github.com/bluesky-social/indigo/util" + "github.com/bluesky-social/indigo/util/cliutil" + "github.com/bluesky-social/indigo/xrpc" + + _ "github.com/joho/godotenv/autoload" + _ "go.uber.org/automaxprocs" + + "github.com/carlmjohnson/versioninfo" + "github.com/urfave/cli/v2" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/exporters/jaeger" + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" + "go.opentelemetry.io/otel/sdk/resource" + tracesdk "go.opentelemetry.io/otel/sdk/trace" + semconv "go.opentelemetry.io/otel/semconv/v1.4.0" + "gorm.io/plugin/opentelemetry/tracing" +) + +func init() { + // control log level using, eg, GOLOG_LOG_LEVEL=debug + //logging.SetAllLoggers(logging.LevelDebug) +} + +func main() { + if err := run(os.Args); err != nil { + slog.Error(err.Error()) + os.Exit(1) + } +} + +func run(args []string) error { + + app := cli.App{ + Name: "relay", + Usage: "atproto Relay daemon", + 
Version: versioninfo.Short(), + } + + app.Flags = []cli.Flag{ + &cli.BoolFlag{ + Name: "jaeger", + }, + &cli.StringFlag{ + Name: "db-url", + Usage: "database connection string for BGS database", + Value: "sqlite://./data/bigsky/bgs.sqlite", + EnvVars: []string{"DATABASE_URL"}, + }, + &cli.BoolFlag{ + Name: "db-tracing", + }, + &cli.StringFlag{ + Name: "plc-host", + Usage: "method, hostname, and port of PLC registry", + Value: "https://plc.directory", + EnvVars: []string{"ATP_PLC_HOST"}, + }, + &cli.BoolFlag{ + Name: "crawl-insecure-ws", + Usage: "when connecting to PDS instances, use ws:// instead of wss://", + }, + &cli.StringFlag{ + Name: "api-listen", + Value: ":2470", + EnvVars: []string{"RELAY_API_LISTEN"}, + }, + &cli.StringFlag{ + Name: "metrics-listen", + Value: ":2471", + EnvVars: []string{"RELAY_METRICS_LISTEN", "BGS_METRICS_LISTEN"}, + }, + &cli.StringFlag{ + Name: "disk-persister-dir", + Usage: "set directory for disk persister (implicitly enables disk persister)", + EnvVars: []string{"RELAY_PERSISTER_DIR"}, + }, + &cli.StringFlag{ + Name: "admin-key", + EnvVars: []string{"RELAY_ADMIN_KEY", "BGS_ADMIN_KEY"}, + }, + &cli.IntFlag{ + Name: "max-metadb-connections", + EnvVars: []string{"MAX_METADB_CONNECTIONS"}, + Value: 40, + }, + &cli.StringFlag{ + Name: "resolve-address", + EnvVars: []string{"RESOLVE_ADDRESS"}, + Value: "1.1.1.1:53", + }, + &cli.BoolFlag{ + Name: "force-dns-udp", + EnvVars: []string{"FORCE_DNS_UDP"}, + }, + &cli.IntFlag{ + Name: "max-fetch-concurrency", + Value: 100, + EnvVars: []string{"MAX_FETCH_CONCURRENCY"}, + }, + &cli.StringFlag{ + Name: "env", + Value: "dev", + EnvVars: []string{"ENVIRONMENT"}, + Usage: "declared hosting environment (prod, qa, etc); used in metrics", + }, + &cli.StringFlag{ + Name: "otel-exporter-otlp-endpoint", + EnvVars: []string{"OTEL_EXPORTER_OTLP_ENDPOINT"}, + }, + &cli.StringFlag{ + Name: "bsky-social-rate-limit-skip", + EnvVars: []string{"BSKY_SOCIAL_RATE_LIMIT_SKIP"}, + Usage: "ratelimit bypass secret token for *.bsky.social domains", + }, + &cli.IntFlag{ + Name: "default-repo-limit", + Value: 100, + EnvVars: []string{"RELAY_DEFAULT_REPO_LIMIT"}, + }, + &cli.IntFlag{ + Name: "concurrency-per-pds", + EnvVars: []string{"RELAY_CONCURRENCY_PER_PDS"}, + Value: 100, + }, + &cli.IntFlag{ + Name: "max-queue-per-pds", + EnvVars: []string{"RELAY_MAX_QUEUE_PER_PDS"}, + Value: 1_000, + }, + &cli.IntFlag{ + Name: "did-cache-size", + Usage: "in-process cache by number of Did documents", + EnvVars: []string{"RELAY_DID_CACHE_SIZE"}, + Value: 5_000_000, + }, + &cli.DurationFlag{ + Name: "event-playback-ttl", + Usage: "time to live for event playback buffering (only applies to disk persister)", + EnvVars: []string{"RELAY_EVENT_PLAYBACK_TTL"}, + Value: 72 * time.Hour, + }, + &cli.StringSliceFlag{ + Name: "next-crawler", + Usage: "forward POST requestCrawl to this url, should be machine root url and not xrpc/requestCrawl, comma separated list", + EnvVars: []string{"RELAY_NEXT_CRAWLER"}, + }, + &cli.StringFlag{ + Name: "trace-induction", + Usage: "file path to log debug trace stuff about induction firehose", + EnvVars: []string{"RELAY_TRACE_INDUCTION"}, + }, + &cli.BoolFlag{ + Name: "time-seq", + EnvVars: []string{"RELAY_TIME_SEQUENCE"}, + Value: false, + Usage: "make outbound firehose sequence number approximately unix microseconds", + }, + } + + app.Action = runBigsky + return app.Run(os.Args) +} + +func setupOTEL(cctx *cli.Context) error { + + env := cctx.String("env") + if env == "" { + env = "dev" + } + if cctx.Bool("jaeger") { + jaegerUrl := 
"http://localhost:14268/api/traces" + exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(jaegerUrl))) + if err != nil { + return err + } + tp := tracesdk.NewTracerProvider( + // Always be sure to batch in production. + tracesdk.WithBatcher(exp), + // Record information about this application in a Resource. + tracesdk.WithResource(resource.NewWithAttributes( + semconv.SchemaURL, + semconv.ServiceNameKey.String("bgs"), + attribute.String("env", env), // DataDog + attribute.String("environment", env), // Others + attribute.Int64("ID", 1), + )), + ) + + otel.SetTracerProvider(tp) + } + + // Enable OTLP HTTP exporter + // For relevant environment variables: + // https://pkg.go.dev/go.opentelemetry.io/otel/exporters/otlp/otlptrace#readme-environment-variables + // At a minimum, you need to set + // OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 + if ep := cctx.String("otel-exporter-otlp-endpoint"); ep != "" { + slog.Info("setting up trace exporter", "endpoint", ep) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + exp, err := otlptracehttp.New(ctx) + if err != nil { + slog.Error("failed to create trace exporter", "error", err) + os.Exit(1) + } + defer func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := exp.Shutdown(ctx); err != nil { + slog.Error("failed to shutdown trace exporter", "error", err) + } + }() + + tp := tracesdk.NewTracerProvider( + tracesdk.WithBatcher(exp), + tracesdk.WithResource(resource.NewWithAttributes( + semconv.SchemaURL, + semconv.ServiceNameKey.String("bgs"), + attribute.String("env", env), // DataDog + attribute.String("environment", env), // Others + attribute.Int64("ID", 1), + )), + ) + otel.SetTracerProvider(tp) + } + + return nil +} + +func runBigsky(cctx *cli.Context) error { + // Trap SIGINT to trigger a shutdown. 
+ signals := make(chan os.Signal, 1) + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) + + logger, logWriter, err := cliutil.SetupSlog(cliutil.LogOptions{}) + if err != nil { + return err + } + + var inductionTraceLog *slog.Logger + + if cctx.IsSet("trace-induction") { + traceFname := cctx.String("trace-induction") + traceFout, err := os.OpenFile(traceFname, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + return fmt.Errorf("%s: could not open trace file: %w", traceFname, err) + } + defer traceFout.Close() + if traceFname != "" { + inductionTraceLog = slog.New(slog.NewJSONHandler(traceFout, &slog.HandlerOptions{Level: slog.LevelDebug})) + } + } else { + inductionTraceLog = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.Level(999)})) + } + + // start observability/tracing (OTEL and jaeger) + if err := setupOTEL(cctx); err != nil { + return err + } + + dburl := cctx.String("db-url") + logger.Info("setting up main database", "url", dburl) + db, err := cliutil.SetupDatabase(dburl, cctx.Int("max-metadb-connections")) + if err != nil { + return err + } + if cctx.Bool("db-tracing") { + if err := db.Use(tracing.NewPlugin()); err != nil { + return err + } + } + if err := db.AutoMigrate(RelaySetting{}); err != nil { + panic(err) + } + + // TODO: add shared external cache + baseDir := identity.BaseDirectory{ + SkipHandleVerification: true, + SkipDNSDomainSuffixes: []string{".bsky.social"}, + TryAuthoritativeDNS: true, + } + cacheDir := identity.NewCacheDirectory(&baseDir, cctx.Int("did-cache-size"), time.Hour*24, time.Minute*2, time.Minute*5) + + repoman := repomgr.NewRepoManager(&cacheDir, inductionTraceLog) + + var persister events.EventPersistence + + dpd := cctx.String("disk-persister-dir") + if dpd == "" { + logger.Info("empty disk-persister-dir, use current working directory") + cwd, err := os.Getwd() + if err != nil { + return err + } + dpd = filepath.Join(cwd, "relay-persist") + } + logger.Info("setting up disk persister", "dir", dpd) + + pOpts := events.DefaultDiskPersistOptions() + pOpts.Retention = cctx.Duration("event-playback-ttl") + pOpts.TimeSequence = cctx.Bool("time-seq") + + // ensure that time-ish sequence stays consistent within a server context + storedTimeSeq, hadStoredTimeSeq, err := getRelaySettingBool(db, "time-seq") + if err != nil { + return err + } + if !hadStoredTimeSeq { + if err := setRelaySettingBool(db, "time-seq", pOpts.TimeSequence); err != nil { + return err + } + } else { + if pOpts.TimeSequence != storedTimeSeq { + return fmt.Errorf("time-seq stored as %v but param/env set as %v", storedTimeSeq, pOpts.TimeSequence) + } + } + + dp, err := events.NewDiskPersistence(dpd, "", db, pOpts) + if err != nil { + return fmt.Errorf("setting up disk persister: %w", err) + } + persister = dp + + evtman := events.NewEventManager(persister) + + repoman.SetEventManager(evtman) + + ratelimitBypass := cctx.String("bsky-social-rate-limit-skip") + + logger.Info("constructing bgs") + bgsConfig := libbgs.DefaultBGSConfig() + bgsConfig.SSL = !cctx.Bool("crawl-insecure-ws") + bgsConfig.ConcurrencyPerPDS = cctx.Int64("concurrency-per-pds") + bgsConfig.MaxQueuePerPDS = cctx.Int64("max-queue-per-pds") + bgsConfig.DefaultRepoLimit = cctx.Int64("default-repo-limit") + bgsConfig.ApplyPDSClientSettings = makePdsClientSetup(ratelimitBypass) + bgsConfig.InductionTraceLog = inductionTraceLog + nextCrawlers := cctx.StringSlice("next-crawler") + if len(nextCrawlers) != 0 { + nextCrawlerUrls := make([]*url.URL, len(nextCrawlers)) + for i, tu := range 
nextCrawlers { + var err error + nextCrawlerUrls[i], err = url.Parse(tu) + if err != nil { + return fmt.Errorf("failed to parse next-crawler url: %w", err) + } + logger.Info("configuring relay for requestCrawl", "host", nextCrawlerUrls[i]) + } + bgsConfig.NextCrawlers = nextCrawlerUrls + } + bgs, err := libbgs.NewBGS(db, repoman, evtman, &cacheDir, bgsConfig) + if err != nil { + return err + } + dp.SetUidSource(bgs) + + if tok := cctx.String("admin-key"); tok != "" { + if err := bgs.CreateAdminToken(tok); err != nil { + return fmt.Errorf("failed to set up admin token: %w", err) + } + } + + // set up metrics endpoint + go func() { + if err := bgs.StartMetrics(cctx.String("metrics-listen")); err != nil { + logger.Error("failed to start metrics endpoint", "err", err) + os.Exit(1) + } + }() + + bgsErr := make(chan error, 1) + + go func() { + err := bgs.Start(cctx.String("api-listen"), logWriter) + bgsErr <- err + }() + + logger.Info("startup complete") + select { + case <-signals: + logger.Info("received shutdown signal") + errs := bgs.Shutdown() + for err := range errs { + logger.Error("error during BGS shutdown", "err", err) + } + case err := <-bgsErr: + if err != nil { + logger.Error("error during BGS startup", "err", err) + } + logger.Info("shutting down") + errs := bgs.Shutdown() + for err := range errs { + logger.Error("error during BGS shutdown", "err", err) + } + } + + logger.Info("shutdown complete") + + return nil +} + +func makePdsClientSetup(ratelimitBypass string) func(c *xrpc.Client) { + return func(c *xrpc.Client) { + if c.Client == nil { + c.Client = util.RobustHTTPClient() + } + if strings.HasSuffix(c.Host, ".bsky.network") { + c.Client.Timeout = time.Minute * 30 + if ratelimitBypass != "" { + c.Headers = map[string]string{ + "x-ratelimit-bypass": ratelimitBypass, + } + } + } else { + // Generic PDS timeout + c.Client.Timeout = time.Minute * 1 + } + } +} + +// RelaySetting is a gorm model +type RelaySetting struct { + Name string `gorm:"primarykey"` + Value string +} + +func getRelaySetting(db *gorm.DB, name string) (value string, found bool, err error) { + var setting RelaySetting + dbResult := db.First(&setting, "name = ?", name) + if errors.Is(dbResult.Error, gorm.ErrRecordNotFound) { + return "", false, nil + } + if dbResult.Error != nil { + return "", false, dbResult.Error + } + return setting.Value, true, nil +} + +func setRelaySetting(db *gorm.DB, name string, value string) error { + return db.Transaction(func(tx *gorm.DB) error { + var setting RelaySetting + found := tx.First(&setting, "name = ?", name) + if errors.Is(found.Error, gorm.ErrRecordNotFound) { + // ok! 
create it + setting.Name = name + setting.Value = value + return tx.Create(&setting).Error + } else if found.Error != nil { + return found.Error + } + setting.Value = value + return tx.Save(&setting).Error + }) +} + +func getRelaySettingBool(db *gorm.DB, name string) (value bool, found bool, err error) { + strval, found, err := getRelaySetting(db, name) + if err != nil || !found { + return false, found, err + } + value, err = strconv.ParseBool(strval) + if err != nil { + return false, false, err + } + return value, true, nil +} +func setRelaySettingBool(db *gorm.DB, name string, value bool) error { + return setRelaySetting(db, name, strconv.FormatBool(value)) +} diff --git a/cmd/relay/models/models.go b/cmd/relay/models/models.go new file mode 100644 index 000000000..69864fcfa --- /dev/null +++ b/cmd/relay/models/models.go @@ -0,0 +1,84 @@ +package models + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "github.com/ipfs/go-cid" + "gorm.io/gorm" +) + +type Uid uint64 + +type DbCID struct { + CID cid.Cid +} + +func (dbc *DbCID) Scan(v interface{}) error { + b, ok := v.([]byte) + if !ok { + return fmt.Errorf("dbcids must get bytes!") + } + + if len(b) == 0 { + return nil + } + + c, err := cid.Cast(b) + if err != nil { + return err + } + + dbc.CID = c + return nil +} + +func (dbc DbCID) Value() (driver.Value, error) { + if !dbc.CID.Defined() { + return nil, fmt.Errorf("cannot serialize undefined cid to database") + } + return dbc.CID.Bytes(), nil +} + +func (dbc DbCID) MarshalJSON() ([]byte, error) { + return json.Marshal(dbc.CID.String()) +} + +func (dbc *DbCID) UnmarshalJSON(b []byte) error { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + + c, err := cid.Decode(s) + if err != nil { + return err + } + + dbc.CID = c + return nil +} + +func (dbc *DbCID) GormDataType() string { + return "bytes" +} + +type PDS struct { + gorm.Model + + Host string `gorm:"unique"` + //Did string + SSL bool + Cursor int64 + Registered bool + Blocked bool + + RateLimit float64 + //CrawlRateLimit float64 + + RepoCount int64 + RepoLimit int64 + + HourlyEventLimit int64 + DailyEventLimit int64 +} diff --git a/cmd/relay/repomgr/metrics.go b/cmd/relay/repomgr/metrics.go new file mode 100644 index 000000000..f58c86def --- /dev/null +++ b/cmd/relay/repomgr/metrics.go @@ -0,0 +1,29 @@ +package repomgr + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +var commitVerifyStarts = promauto.NewCounter(prometheus.CounterOpts{ + Name: "repomgr_commit_verify_starts", +}) + +var commitVerifyWarnings = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "repomgr_commit_verify_warnings", +}, []string{"host", "warn"}) + +// verify error and short code for why +var commitVerifyErrors = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "repomgr_commit_verify_errors", +}, []string{"host", "err"}) + +// ok and *fully verified* +var commitVerifyOk = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "repomgr_commit_verify_ok", +}, []string{"host"}) + +// it's ok, but... 
{old protocol, no previous root cid, ...} +var commitVerifyOkish = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "repomgr_commit_verify_okish", +}, []string{"host", "but"}) diff --git a/cmd/relay/repomgr/repomgr.go b/cmd/relay/repomgr/repomgr.go new file mode 100644 index 000000000..feebf0067 --- /dev/null +++ b/cmd/relay/repomgr/repomgr.go @@ -0,0 +1,398 @@ +package repomgr + +import ( + "bytes" + "context" + "errors" + "fmt" + "log/slog" + "sync" + "sync/atomic" + "time" + + atproto "github.com/bluesky-social/indigo/api/atproto" + "github.com/bluesky-social/indigo/atproto/identity" + atrepo "github.com/bluesky-social/indigo/atproto/repo" + "github.com/bluesky-social/indigo/atproto/syntax" + "github.com/bluesky-social/indigo/cmd/relay/events" + "github.com/bluesky-social/indigo/cmd/relay/models" + "github.com/ipfs/go-cid" + "go.opentelemetry.io/otel" +) + +const defaultMaxRevFuture = time.Hour + +func NewRepoManager(directory identity.Directory, inductionTraceLog *slog.Logger) *RepoManager { + maxRevFuture := defaultMaxRevFuture // TODO: configurable + ErrRevTooFarFuture := fmt.Errorf("new rev is > %s in the future", maxRevFuture) + + return &RepoManager{ + userLocks: make(map[models.Uid]*userLock), + log: slog.Default().With("system", "repomgr"), + inductionTraceLog: inductionTraceLog, + directory: directory, + + maxRevFuture: maxRevFuture, + ErrRevTooFarFuture: ErrRevTooFarFuture, + AllowSignatureNotFound: true, // TODO: configurable + } +} + +func (rm *RepoManager) SetEventManager(events *events.EventManager) { + rm.events = events +} + +// RepoManager is a poorly defined chunk of code +// TODO: RepoManager should probably merge with what calls it or what it calls; probably move HandleCommit into bgs.go +type RepoManager struct { + lklk sync.Mutex + userLocks map[models.Uid]*userLock + + events *events.EventManager + + log *slog.Logger + inductionTraceLog *slog.Logger + + directory identity.Directory + + maxRevFuture time.Duration + ErrRevTooFarFuture error + + // AllowSignatureNotFound enables counting messages without findable public key to pass through with a warning counter + AllowSignatureNotFound bool +} + +type NextCommitHandler interface { + HandleCommit(ctx context.Context, host *models.PDS, uid models.Uid, did string, commit *atproto.SyncSubscribeRepos_Commit) error +} + +type userLock struct { + lk sync.Mutex + waiters atomic.Int32 +} + +// lockUser re-serializes access per-user after events may have been fanned out to many worker threads by events/schedulers/parallel +func (rm *RepoManager) lockUser(ctx context.Context, user models.Uid) func() { + ctx, span := otel.Tracer("repoman").Start(ctx, "userLock") + defer span.End() + + rm.lklk.Lock() + + ulk, ok := rm.userLocks[user] + if !ok { + ulk = &userLock{} + rm.userLocks[user] = ulk + } + + ulk.waiters.Add(1) + + rm.lklk.Unlock() + + ulk.lk.Lock() + + return func() { + rm.lklk.Lock() + defer rm.lklk.Unlock() + + ulk.lk.Unlock() + + nv := ulk.waiters.Add(-1) + + if nv == 0 { + delete(rm.userLocks, user) + } + } +} + +type IUser interface { + GetUid() models.Uid + GetDid() string +} + +type UserPrev interface { + GetCid() cid.Cid + GetRev() syntax.TID +} + +func (rm *RepoManager) HandleCommit(ctx context.Context, host *models.PDS, user IUser, commit *atproto.SyncSubscribeRepos_Commit, prevRoot UserPrev) (newRoot *cid.Cid, err error) { + uid := user.GetUid() + unlock := rm.lockUser(ctx, uid) + defer unlock() + repoFragment, err := rm.VerifyCommitMessage(ctx, host, commit, prevRoot) + if err != nil 
{ + return nil, err + } + newRootCid, err := repoFragment.MST.RootCID() + if err != nil { + return nil, err + } + if rm.events != nil { + xe := &events.XRPCStreamEvent{ + RepoCommit: commit, + PrivUid: uid, + } + err = rm.events.AddEvent(ctx, xe) + if err != nil { + rm.log.Error("events handle commit", "err", err) + } + } + return newRootCid, nil +} + +var ErrNewRevBeforePrevRev = errors.New("new rev is before previous rev") + +func (rm *RepoManager) VerifyCommitMessage(ctx context.Context, host *models.PDS, msg *atproto.SyncSubscribeRepos_Commit, prevRoot UserPrev) (*atrepo.Repo, error) { + hostname := host.Host + hasWarning := false + commitVerifyStarts.Inc() + logger := slog.Default().With("did", msg.Repo, "rev", msg.Rev, "seq", msg.Seq, "time", msg.Time) + + did, err := syntax.ParseDID(msg.Repo) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "did").Inc() + return nil, err + } + rev, err := syntax.ParseTID(msg.Rev) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "tid").Inc() + return nil, err + } + if prevRoot != nil { + prevRev := prevRoot.GetRev() + curTime := rev.Time() + prevTime := prevRev.Time() + if curTime.Before(prevTime) { + commitVerifyErrors.WithLabelValues(hostname, "revb").Inc() + dt := prevTime.Sub(curTime) + return nil, fmt.Errorf("new rev is before previous rev by %s", dt.String()) + } + } + if rev.Time().After(time.Now().Add(rm.maxRevFuture)) { + commitVerifyErrors.WithLabelValues(hostname, "revf").Inc() + return nil, rm.ErrRevTooFarFuture + } + _, err = syntax.ParseDatetime(msg.Time) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "time").Inc() + return nil, err + } + + if msg.TooBig { + //logger.Warn("event with tooBig flag set") + commitVerifyWarnings.WithLabelValues(hostname, "big").Inc() + rm.inductionTraceLog.Warn("commit tooBig", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) + hasWarning = true + } + if msg.Rebase { + //logger.Warn("event with rebase flag set") + commitVerifyWarnings.WithLabelValues(hostname, "reb").Inc() + rm.inductionTraceLog.Warn("commit rebase", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) + hasWarning = true + } + + commit, repoFragment, err := atrepo.LoadFromCAR(ctx, bytes.NewReader([]byte(msg.Blocks))) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "car").Inc() + return nil, err + } + + if commit.Rev != rev.String() { + commitVerifyErrors.WithLabelValues(hostname, "rev").Inc() + return nil, fmt.Errorf("rev did not match commit") + } + if commit.DID != did.String() { + commitVerifyErrors.WithLabelValues(hostname, "did2").Inc() + return nil, fmt.Errorf("rev did not match commit") + } + + err = rm.VerifyCommitSignature(ctx, commit, hostname, &hasWarning) + if err != nil { + // signature errors are metrics counted inside VerifyCommitSignature() + return nil, err + } + + // load out all the records + for _, op := range msg.Ops { + if (op.Action == "create" || op.Action == "update") && op.Cid != nil { + c := (*cid.Cid)(op.Cid) + nsid, rkey, err := syntax.ParseRepoPath(op.Path) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "opp").Inc() + return nil, fmt.Errorf("invalid repo path in ops list: %w", err) + } + val, err := repoFragment.GetRecordCID(ctx, nsid, rkey) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "rcid").Inc() + return nil, err + } + if *c != *val { + commitVerifyErrors.WithLabelValues(hostname, "opc").Inc() + return nil, fmt.Errorf("record op doesn't match MST tree value") + } + _, err = 
repoFragment.GetRecordBytes(ctx, nsid, rkey) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "rec").Inc() + return nil, err + } + } + } + + // TODO: once firehose format is fully shipped, remove this + for _, o := range msg.Ops { + switch o.Action { + case "delete": + if o.Prev == nil { + logger.Debug("can't invert legacy op", "action", o.Action) + rm.inductionTraceLog.Warn("commit delete op", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) + commitVerifyOkish.WithLabelValues(hostname, "del").Inc() + return repoFragment, nil + } + case "update": + if o.Prev == nil { + logger.Debug("can't invert legacy op", "action", o.Action) + rm.inductionTraceLog.Warn("commit update op", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) + commitVerifyOkish.WithLabelValues(hostname, "up").Inc() + return repoFragment, nil + } + } + } + + if msg.PrevData != nil { + c := (*cid.Cid)(msg.PrevData) + if prevRoot != nil { + if *c != prevRoot.GetCid() { + commitVerifyWarnings.WithLabelValues(hostname, "pr").Inc() + rm.inductionTraceLog.Warn("commit prevData mismatch", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) + hasWarning = true + } + } else { + // see counter below for okish "new" + } + + // check internal consistency that claimed previous root matches the rest of this message + ops, err := ParseCommitOps(msg.Ops) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "pop").Inc() + return nil, err + } + ops, err = atrepo.NormalizeOps(ops) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "nop").Inc() + return nil, err + } + + invTree := repoFragment.MST.Copy() + for _, op := range ops { + if err := atrepo.InvertOp(&invTree, &op); err != nil { + commitVerifyErrors.WithLabelValues(hostname, "inv").Inc() + return nil, err + } + } + computed, err := invTree.RootCID() + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "it").Inc() + return nil, err + } + if *computed != *c { + // this is self-inconsistent malformed data + commitVerifyErrors.WithLabelValues(hostname, "pd").Inc() + return nil, fmt.Errorf("inverted tree root didn't match prevData") + } + //logger.Debug("prevData matched", "prevData", c.String(), "computed", computed.String()) + + if prevRoot == nil { + commitVerifyOkish.WithLabelValues(hostname, "new").Inc() + } else if hasWarning { + commitVerifyOkish.WithLabelValues(hostname, "warn").Inc() + } else { + // TODO: would it be better to make everything "okish"? + // commitVerifyOkish.WithLabelValues(hostname, "ok").Inc() + commitVerifyOk.WithLabelValues(hostname).Inc() + } + } else { + // this source is still on old protocol without new prevData field + commitVerifyOkish.WithLabelValues(hostname, "old").Inc() + } + + return repoFragment, nil +} + +// TODO: lift back to indigo/atproto/repo util code? 
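+// Hypothetical usage sketch for the function below: it shows how firehose #commit ops are
+// expected to map onto atrepo.Operation values. The record paths and the newCid/oldCid
+// LexLink pointers are illustrative assumptions, not values taken from this patch.
+// ```
+// ops := []*atproto.SyncSubscribeRepos_RepoOp{
+// 	{Action: "create", Path: "app.bsky.feed.post/3kabc", Cid: &newCid},                  // Value set, Prev nil
+// 	{Action: "update", Path: "app.bsky.actor.profile/self", Cid: &newCid, Prev: &oldCid}, // both set
+// 	{Action: "delete", Path: "app.bsky.feed.like/3kdef", Prev: &oldCid},                  // Prev set, Value nil
+// }
+// parsed, err := ParseCommitOps(ops) // any other Cid/Prev combination is rejected as malformed
+// ```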
+func ParseCommitOps(ops []*atproto.SyncSubscribeRepos_RepoOp) ([]atrepo.Operation, error) { + out := []atrepo.Operation{} + for _, rop := range ops { + switch rop.Action { + case "create": + if rop.Cid == nil || rop.Prev != nil { + return nil, fmt.Errorf("invalid repoOp: create") + } + op := atrepo.Operation{ + Path: rop.Path, + Prev: nil, + Value: (*cid.Cid)(rop.Cid), + } + out = append(out, op) + case "delete": + if rop.Cid != nil || rop.Prev == nil { + return nil, fmt.Errorf("invalid repoOp: delete") + } + op := atrepo.Operation{ + Path: rop.Path, + Prev: (*cid.Cid)(rop.Prev), + Value: nil, + } + out = append(out, op) + case "update": + if rop.Cid == nil || rop.Prev == nil { + return nil, fmt.Errorf("invalid repoOp: update") + } + op := atrepo.Operation{ + Path: rop.Path, + Prev: (*cid.Cid)(rop.Prev), + Value: (*cid.Cid)(rop.Cid), + } + out = append(out, op) + default: + return nil, fmt.Errorf("invalid repoOp action: %s", rop.Action) + } + } + return out, nil +} + +// VerifyCommitSignature get's repo's registered public key from Identity Directory, verifies Commit +// hostname is just for metrics in case of error +func (rm *RepoManager) VerifyCommitSignature(ctx context.Context, commit *atrepo.Commit, hostname string, hasWarning *bool) error { + if rm.directory == nil { + return nil + } + xdid, err := syntax.ParseDID(commit.DID) + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "sig1").Inc() + return fmt.Errorf("bad car DID, %w", err) + } + ident, err := rm.directory.LookupDID(ctx, xdid) + if err != nil { + if rm.AllowSignatureNotFound { + // allow not-found conditions to pass without signature check + commitVerifyWarnings.WithLabelValues(hostname, "nok").Inc() + if hasWarning != nil { + *hasWarning = true + } + return nil + } + commitVerifyErrors.WithLabelValues(hostname, "sig2").Inc() + return fmt.Errorf("DID lookup failed, %w", err) + } + pk, err := ident.GetPublicKey("atproto") + if err != nil { + commitVerifyErrors.WithLabelValues(hostname, "sig3").Inc() + return fmt.Errorf("no atproto pubkey, %w", err) + } + err = commit.VerifySignature(pk) + if err != nil { + // TODO: if the DID document was stale, force re-fetch from source and re-try if pubkey has changed + commitVerifyErrors.WithLabelValues(hostname, "sig4").Inc() + return fmt.Errorf("invalid signature, %w", err) + } + return nil +} diff --git a/events/consumer.go b/events/consumer.go index f00cd79c7..09f2ee778 100644 --- a/events/consumer.go +++ b/events/consumer.go @@ -126,6 +126,7 @@ func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, go func() { t := time.NewTicker(time.Second * 30) defer t.Stop() + failcount := 0 for { @@ -133,6 +134,12 @@ func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, case <-t.C: if err := con.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil { log.Warn("failed to ping", "err", err) + failcount++ + if failcount >= 4 { + log.Error("too many ping fails", "count", failcount) + con.Close() + return + } } case <-ctx.Done(): con.Close() @@ -169,7 +176,7 @@ func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, mt, rawReader, err := con.NextReader() if err != nil { - return err + return fmt.Errorf("con err at read: %w", err) } switch mt { diff --git a/events/repostream.go b/events/repostream.go deleted file mode 100644 index 01f8b4bef..000000000 --- a/events/repostream.go +++ /dev/null @@ -1,60 +0,0 @@ -package events - -import ( - "context" - - 
"github.com/bluesky-social/indigo/repomgr" - - "github.com/gorilla/websocket" - cid "github.com/ipfs/go-cid" -) - -type LiteStreamHandleFunc func(op repomgr.EventKind, seq int64, path string, did string, rcid *cid.Cid, rec any) error - -func ConsumeRepoStreamLite2(ctx context.Context, con *websocket.Conn, cb LiteStreamHandleFunc) error { - /* - return HandleRepoStream(ctx, con, &RepoStreamCallbacks{ - RepoCommit: func(evt *comatproto.SyncSubscribeRepos_Commit) error { - if evt.TooBig { - log.Errorf("skipping too big events for now: %d", evt.Seq) - return nil - } - r, err := repo.ReadRepoFromCar(ctx, bytes.NewReader(evt.Blocks)) - if err != nil { - return fmt.Errorf("reading repo from car (seq: %d, len: %d): %w", evt.Seq, len(evt.Blocks), err) - } - - for _, op := range evt.Ops { - ek := repomgr.EventKind(op.Action) - switch ek { - case repomgr.EvtKindCreateRecord, repomgr.EvtKindUpdateRecord: - rc, rec, err := r.GetRecord(ctx, op.Path) - if err != nil { - e := fmt.Errorf("getting record %s (%s) within seq %d for %s: %w", op.Path, *op.Cid, evt.Seq, evt.Repo, err) - log.Error(e) - continue - } - - if lexutil.LexLink(rc) != *op.Cid { - // TODO: do we even error here? - return fmt.Errorf("mismatch in record and op cid: %s != %s", rc, *op.Cid) - } - - if err := cb(ek, evt.Seq, op.Path, evt.Repo, &rc, rec); err != nil { - log.Errorf("event consumer callback (%s): %s", ek, err) - continue - } - - case repomgr.EvtKindDeleteRecord: - if err := cb(ek, evt.Seq, op.Path, evt.Repo, nil, nil); err != nil { - log.Errorf("event consumer callback (%s): %s", ek, err) - continue - } - } - } - return nil - }, - }) - */ - return nil -} diff --git a/gen/main.go b/gen/main.go index 5d5d432f2..b3ae18705 100644 --- a/gen/main.go +++ b/gen/main.go @@ -94,6 +94,7 @@ func main() { atproto.LexiconSchema{}, atproto.RepoStrongRef{}, atproto.SyncSubscribeRepos_Commit{}, + atproto.SyncSubscribeRepos_Sync{}, atproto.SyncSubscribeRepos_Handle{}, atproto.SyncSubscribeRepos_Identity{}, atproto.SyncSubscribeRepos_Account{}, diff --git a/models/models.go b/models/models.go index 9781e75bd..d61ec2311 100644 --- a/models/models.go +++ b/models/models.go @@ -104,7 +104,7 @@ type FollowRecord struct { type PDS struct { gorm.Model - Host string + Host string `gorm:"unique"` Did string SSL bool Cursor int64 diff --git a/mst/diff.go b/mst/diff.go index 1c70e82e6..235ad1bd5 100644 --- a/mst/diff.go +++ b/mst/diff.go @@ -6,7 +6,7 @@ import ( "github.com/bluesky-social/indigo/util" cid "github.com/ipfs/go-cid" - blockstore "github.com/ipfs/go-ipfs-blockstore" + cbor "github.com/ipfs/go-ipld-cbor" ) type DiffOp struct { @@ -18,7 +18,7 @@ type DiffOp struct { } // TODO: this code isn't great, should be rewritten on top of the baseline datastructures once functional and correct -func DiffTrees(ctx context.Context, bs blockstore.Blockstore, from, to cid.Cid) ([]*DiffOp, error) { +func DiffTrees(ctx context.Context, bs cbor.IpldBlockstore, from, to cid.Cid) ([]*DiffOp, error) { cst := util.CborStore(bs) if from == cid.Undef { @@ -185,7 +185,7 @@ func nodeEntriesEqual(a, b *nodeEntry) bool { return false } -func identityDiff(ctx context.Context, bs blockstore.Blockstore, root cid.Cid) ([]*DiffOp, error) { +func identityDiff(ctx context.Context, bs cbor.IpldBlockstore, root cid.Cid) ([]*DiffOp, error) { cst := util.CborStore(bs) tt := LoadMST(cst, root) diff --git a/repo/repo.go b/repo/repo.go index db66e2c97..bcde64ac9 100644 --- a/repo/repo.go +++ b/repo/repo.go @@ -43,7 +43,7 @@ type UnsignedCommit struct { type Repo 
struct { sc SignedCommit cst cbor.IpldStore - bs blockstore.Blockstore + bs cbor.IpldBlockstore repoCid cid.Cid @@ -74,7 +74,7 @@ func (uc *UnsignedCommit) BytesForSigning() ([]byte, error) { return buf.Bytes(), nil } -func IngestRepo(ctx context.Context, bs blockstore.Blockstore, r io.Reader) (cid.Cid, error) { +func IngestRepo(ctx context.Context, bs cbor.IpldBlockstore, r io.Reader) (cid.Cid, error) { ctx, span := otel.Tracer("repo").Start(ctx, "Ingest") defer span.End() @@ -110,7 +110,7 @@ func ReadRepoFromCar(ctx context.Context, r io.Reader) (*Repo, error) { return OpenRepo(ctx, bs, root) } -func NewRepo(ctx context.Context, did string, bs blockstore.Blockstore) *Repo { +func NewRepo(ctx context.Context, did string, bs cbor.IpldBlockstore) *Repo { cst := util.CborStore(bs) t := mst.NewEmptyMST(cst) @@ -128,7 +128,7 @@ func NewRepo(ctx context.Context, did string, bs blockstore.Blockstore) *Repo { } } -func OpenRepo(ctx context.Context, bs blockstore.Blockstore, root cid.Cid) (*Repo, error) { +func OpenRepo(ctx context.Context, bs cbor.IpldBlockstore, root cid.Cid) (*Repo, error) { cst := util.CborStore(bs) var sc SignedCommit @@ -173,7 +173,7 @@ func (r *Repo) SignedCommit() SignedCommit { return r.sc } -func (r *Repo) Blockstore() blockstore.Blockstore { +func (r *Repo) Blockstore() cbor.IpldBlockstore { return r.bs } @@ -435,11 +435,11 @@ func (r *Repo) DiffSince(ctx context.Context, oldrepo cid.Cid) ([]*mst.DiffOp, e return mst.DiffTrees(ctx, r.bs, oldTree, curptr) } -func (r *Repo) CopyDataTo(ctx context.Context, bs blockstore.Blockstore) error { +func (r *Repo) CopyDataTo(ctx context.Context, bs cbor.IpldBlockstore) error { return copyRecCbor(ctx, r.bs, bs, r.sc.Data, make(map[cid.Cid]struct{})) } -func copyRecCbor(ctx context.Context, from, to blockstore.Blockstore, c cid.Cid, seen map[cid.Cid]struct{}) error { +func copyRecCbor(ctx context.Context, from, to cbor.IpldBlockstore, c cid.Cid, seen map[cid.Cid]struct{}) error { if _, ok := seen[c]; ok { return nil } diff --git a/util/cbor.go b/util/cbor.go index 70d040f1d..d6bc6d284 100644 --- a/util/cbor.go +++ b/util/cbor.go @@ -1,12 +1,11 @@ package util import ( - blockstore "github.com/ipfs/go-ipfs-blockstore" cbor "github.com/ipfs/go-ipld-cbor" mh "github.com/multiformats/go-multihash" ) -func CborStore(bs blockstore.Blockstore) *cbor.BasicIpldStore { +func CborStore(bs cbor.IpldBlockstore) *cbor.BasicIpldStore { cst := cbor.NewCborStore(bs) cst.DefaultMultihash = mh.SHA2_256 return cst diff --git a/util/cliutil/util.go b/util/cliutil/util.go index 8991945ca..38a79af2e 100644 --- a/util/cliutil/util.go +++ b/util/cliutil/util.go @@ -283,7 +283,7 @@ func firstenv(env_var_names ...string) string { // The env vars were derived from ipfs logging library, and also respond to some GOLOG_ vars from that library, // but BSKYLOG_ variables are preferred because imported code still using the ipfs log library may misbehave // if some GOLOG values are set, especially GOLOG_FILE. 
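+// SetupSlog now also returns the io.Writer backing the logger so callers can reuse the same
+// log sink elsewhere. A minimal caller sketch, mirroring cmd/relay/main.go in this change
+// (error handling elided):
+// ```
+// logger, logWriter, err := cliutil.SetupSlog(cliutil.LogOptions{})
+// if err != nil { ... }
+// // the writer can then be handed to other components, e.g. bgs.Start(cctx.String("api-listen"), logWriter)
+// ```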
-func SetupSlog(options LogOptions) (*slog.Logger, error) { +func SetupSlog(options LogOptions) (*slog.Logger, io.Writer, error) { fmt.Fprintf(os.Stderr, "SetupSlog\n") var hopts slog.HandlerOptions hopts.Level = slog.LevelInfo @@ -306,7 +306,7 @@ func SetupSlog(options LogOptions) (*slog.Logger, error) { case "error": hopts.Level = slog.LevelError default: - return nil, fmt.Errorf("unknown log level: %#v", options.LogLevel) + return nil, nil, fmt.Errorf("unknown log level: %#v", options.LogLevel) } } if options.LogFormat == "" { @@ -319,7 +319,7 @@ func SetupSlog(options LogOptions) (*slog.Logger, error) { if format == "json" || format == "text" { // ok } else { - return nil, fmt.Errorf("invalid log format: %#v", options.LogFormat) + return nil, nil, fmt.Errorf("invalid log format: %#v", options.LogFormat) } options.LogFormat = format } @@ -332,7 +332,7 @@ func SetupSlog(options LogOptions) (*slog.Logger, error) { if rotateBytesStr != "" { rotateBytes, err := strconv.ParseInt(rotateBytesStr, 10, 64) if err != nil { - return nil, fmt.Errorf("invalid BSKYLOG_ROTATE_BYTES value: %w", err) + return nil, nil, fmt.Errorf("invalid BSKYLOG_ROTATE_BYTES value: %w", err) } options.LogRotateBytes = rotateBytes } @@ -343,7 +343,7 @@ func SetupSlog(options LogOptions) (*slog.Logger, error) { if keepOldStr != "" { keepOld, err := strconv.ParseInt(keepOldStr, 10, 64) if err != nil { - return nil, fmt.Errorf("invalid BSKYLOG_ROTATE_KEEP value: %w", err) + return nil, nil, fmt.Errorf("invalid BSKYLOG_ROTATE_KEEP value: %w", err) } keepOldUnset = false options.KeepOld = int(keepOld) @@ -368,7 +368,7 @@ func SetupSlog(options LogOptions) (*slog.Logger, error) { var err error out, err = os.Create(options.LogPath) if err != nil { - return nil, fmt.Errorf("%s: %w", options.LogPath, err) + return nil, nil, fmt.Errorf("%s: %w", options.LogPath, err) } fmt.Fprintf(os.Stderr, "SetupSlog create %#v\n", options.LogPath) } @@ -379,7 +379,7 @@ func SetupSlog(options LogOptions) (*slog.Logger, error) { case "json": handler = slog.NewJSONHandler(out, &hopts) default: - return nil, fmt.Errorf("unknown log format: %#v", options.LogFormat) + return nil, nil, fmt.Errorf("unknown log format: %#v", options.LogFormat) } logger := slog.New(handler) slog.SetDefault(logger) @@ -389,7 +389,7 @@ func SetupSlog(options LogOptions) (*slog.Logger, error) { fmt.Fprintf(os.Stdout, "%s\n", filepath.Join(templateDirPart, ent.Name())) } SetIpfsWriter(out, options.LogFormat, options.LogLevel) - return logger, nil + return logger, out, nil } type logRotateWriter struct {
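+// How the new cmd/relay/events pieces are intended to fit together; a minimal, hypothetical
+// wiring sketch (the relay's main.go above actually uses the disk persister; the path, stub
+// broadcaster, and no-op playback callback here are illustrative assumptions only):
+// ```
+// opts := events.DefaultPebblePersistOptions
+// opts.DbPath = "/tmp/relay-events.pebble"
+// pp, err := events.NewPebblePersistance(&opts)
+// if err != nil { ... }
+// defer pp.Shutdown(context.Background())
+//
+// // Persist() invokes this hook after every successful write; in the relay the hook is
+// // presumably installed when the persister is passed to events.NewEventManager.
+// pp.SetEventBroadcaster(func(evt *events.XRPCStreamEvent) {})
+//
+// // Expire old events on the GCPeriod/PersistDuration configured in opts.
+// go pp.GCThread(context.Background())
+//
+// // Replay everything currently on disk, starting from sequence 0.
+// err = pp.Playback(context.Background(), 0, func(evt *events.XRPCStreamEvent) error {
+// 	return nil
+// })
+// ```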