一、问题引入
目前我们所知道的一些常见的最短路算法有 dijkstra、spfa、floyd。
dijkstra 和 spfa 是单源最短路,floyd 是多源最短路。
如果我们需要在 等级的时间复杂度下求出多源最短路,并且图存在负权,那么它就叉掉了这三种最短路算法,因为 dijkstra 无法处理负权,spfa 跑 次虽然一般跑不满,但是只要卡一下,就可以卡到 的时间复杂度,floyd 时间复杂度 ,也过不了,这个时候,就出现了 Johnson 算法,它是一种依靠 dijkstra 和 spfa 的算法。
二、Johnson 算法流程
新建超级源点,给每个点连接一条边。
算出每个点到超级源点的最短距离 ,使用 spfa,因为此时存在负权。
给每条边的 更新为 , 和 为这条边连接的两个节点。
计算最短路,使用 dijkstra,此时已经不存在负权。
输出时要减去 。
三、算法正确性证明
考虑如何将 dijkstra 优化成可以求负边权的算法。
首先有一种很容易想到的思路,就是将所有边都加上一个数,使得所有边的权值都变成非负整数,但是这种想法是错的,因为如果这样就会出现路径经过的边数不同导致最短路计算错误,所以我们需要使每个点加上一个数,使得任何一个最短路多余的值进行消掉之后只剩下开头点和结尾点,才能正确算出最短路。所以我们要将每一个点 设置一个值 ,然后假设一条路径的边进行加工后,假设它的路径起点为 ,终点为 ,则路径为 ,其权值为 ,消掉后变成了 ,于是这样权值就和边的数量一点关系也没有了,目前已经证明了一半,那如何证明更新后的边权一定非负?首先对于任意一条边 ,它一定满足 ,移项后得 ,则 ,于是我们就证明了任意一条边更新后权值绝对非负。
四、例题
P5905 【模板】全源最短路(Johnson)
代码:
#include <bits/stdc++.h>using namespace std;namespace fast_IO {#define FASTIO#define IOSIZE 100000
char ibuf[IOSIZE], obuf[IOSIZE]; char *p1 = ibuf, *p2 = ibuf, *p3 = obuf;#ifdef ONLINE_JUDGE#define getchar() ((p1==p2)and(p2=(p1=ibuf)+fread(ibuf,1,IOSIZE,stdin),p1==p2)?(EOF):(*p1++))#define putchar(x) ((p3==obuf+IOSIZE)&&(fwrite(obuf,p3-obuf,1,stdout),p3=obuf),*p3++=x)#endif//fread in OJ, stdio in local
#define isdigit(ch) (ch>47&&ch<58)#define isspace(ch) (ch<33)
template<typename T> inline T read() {
T s = 0; int w = 1; char ch; while (ch = getchar(), !isdigit(ch) and (ch != EOF)) if (ch == '-') w = -1; if (ch == EOF) return false; while (isdigit(ch)) s = s * 10 + ch - 48, ch = getchar(); return s * w;
} template<typename T> inline bool read(T &s) {
s = 0; int w = 1; char ch; while (ch = getchar(), !isdigit(ch) and (ch != EOF)) if (ch == '-') w = -1; if (ch == EOF) return false; while (isdigit(ch)) s = s * 10 + ch - 48, ch = getchar(); return s *= w, true;
} inline bool read(char &s) { while (s = getchar(), isspace(s)); return true;
} inline bool read(char *s) { char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) *s++ = ch, ch = getchar();
*s = '\000'; return true;
} template<typename T> inline void print(T x) { if (x < 0) putchar('-'), x = -x; if (x > 9) print(x / 10); putchar(x % 10 + 48);
} inline void print(char x) { putchar(x);
} inline void print(char *x) { while (*x) putchar(*x++);
} inline void print(const char *x) { for (int i = 0; x[i]; i++) putchar(x[i]);
}#ifdef _GLIBCXX_STRING
inline bool read(std::string& s) {
s = ""; char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) s += ch, ch = getchar(); return true;
} inline void print(std::string x) { for (int i = 0, n = x.size(); i < n; i++) putchar(x[i]);
}#endif//string
template<typename T, typename... T1> inline int read(T& a, T1&... other) { return read(a) + read(other...);
} template<typename T, typename... T1> inline void print(T a, T1... other) { print(a); print(other...);
}
struct Fast_IO {
~Fast_IO() { fwrite(obuf, p3 - obuf, 1, stdout);
}
} io; template<typename T> Fast_IO& operator >> (Fast_IO &io, T &b) { return read(b), io;
} template<typename T> Fast_IO& operator << (Fast_IO &io, T b) { return print(b), io;
}#define cout io#define cin io#define endl '\n'}using namespace fast_IO;const int N = 3e3+5;struct node{ int x; int w; int operator<(const node&a)const
{ return w>a.w;
}
};
vector<node>a[N];int vis[N];long long h[N];long long d[N];int t[N];signed main(){ int n,m;
cin >> n >> m; for(int i = 1;i<=m;i++)
{ int x,y,w;
cin >> x >> y >> w;
a[x].push_back({y,w});
} for(int i = 1;i<=n;i++)
{
a[0].push_back({i,0});
} memset(h,0x3f,sizeof(h));
queue<int>q;
q.push(0);
vis[0] = 1;
h[0] = 0; while(q.size())
{ int x = q.front();
q.pop();
vis[x] = 0; for(int i = 0;i<a[x].size();i++)
{ int v = a[x][i].x; int w = a[x][i].w; if(h[v]>h[x]+w)
{
h[v] = h[x]+w; if(!vis[v])
{
vis[v] = 1;
q.push(v);
t[v]++; if(t[v] == n+1)
{ printf("-1"); return 0;
}
}
}
}
} for(int i = 1;i<=n;i++)
{ for(int j = 0;j<a[i].size();j++)
{ int v = a[i][j].x;
a[i][j].w+=h[i]-h[v];
}
} for(int i = 1;i<=n;i++)
{
priority_queue<node>q;
q.push({i,0}); memset(d,0x3f,sizeof(d)); memset(vis,0,sizeof(vis));
d[i] = 0; while(q.size())
{ int x = q.top().x;
q.pop(); if(vis[x])
{ continue;
}
vis[x] = 1; for(int i = 0;i<a[x].size();i++)
{ int v = a[x][i].x; int w = a[x][i].w; if(d[v]>d[x]+w)
{
d[v] = d[x]+w;
q.push({v,d[v]});
}
}
} long long sum = 0; for(int j = 1;j<=n;j++)
{ if(d[j] == d[0])
{
sum+=(long long)j*(long long)1000000000;
} else
{
sum+=j*(d[j]+h[j]-h[i]);
}
}
cout << sum << "\n";
} return 0;
}
