AcWing 1221. 四平方和 (枚举 + 二分/哈希)

  • 时间:
  • 浏览:
  • 来源:互联网

题目描述

原题链接

分析

先考虑一下暴力做法:
即枚举 a , b , c a,b,c a,b,c, 求出 d d d, 判断 a 2 + b 2 + c 2 + d 2 是 否 等 于 n a^2+b^2+c^2+d^2 是否等于 n a2+b2+c2+d2n, 大约 1 e 9 1e9 1e9的复杂度,肯定会超时
下面考虑如何优化暴力做法:
不妨以空间换时间:
先枚举 c , d c,d c,d的所有可能,储存下来,再枚举 a , b a,b a,b,查找之前存储下来的 c , d c,d c,d,是否有符合要求的,从而得到答案
(符合要求即: a 2 + b 2 + c 2 + d 2 = n a^2+b^2+c^2+d^2 = n a2+b2+c2+d2=n a , b , c , d a,b,c,d a,b,c,d按联合主键上升)


如何查找出符合要求的 c , d c,d c,d呢? 这里有两种做法(二分/哈希)
二分 O ( n 2 l o g n ) O(n^2logn) O(n2logn)
对储存下来的 c , d c,d c,d,按 c 2 + d 2 , c , d c^2+d^2,c,d c2+d2,c,d的优先级升序排列, 而后二分出一个最小的符合要求的 r e s res res,使 r e s = n − a 2 − b 2 res = n - a^2 - b^2 res=na2b2, 从而求出 c , d c,d c,d
至于如何满足联合主键上升: 枚举 a , b a,b a,b, 将 c , d c,d c,d按上述规则排序即可
哈希 O ( n 2 ) O(n^2) O(n2):
建立 r e s res res c c c的哈希映射, 直接 O ( 1 ) O(1) O(1)找出符合要求的 r e s res res, 使 r e s = n − a 2 − b 2 res = n - a^2 - b^2 res=na2b2, 从而求出 c , d c,d c,d.
而建立哈希映射的过程,即可保证 c , d c,d c,d按联合主键上升
N < 5 e 6 N<5e6 N<5e6, 可以直接利用数组建立映射, 不要使用unodered_map,会超时.

Y总视频讲解(需要权限)

实现

枚举 + 二分

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 5e6 + 9;
struct node
{
    int c, d, sum;
};
node no[N*10];
int n, cnt;
bool cmp(node x, node y)
{
    if(x.sum != y.sum) return x.sum < y.sum;
    if(x.c != y.c) return x.c < y.c;
    if(x.d != y.d) return x.d < y.d;
}
int solve(int res)
{
    int left = 0, right = cnt;
    while(left < right)
    {
        int mid = (left + right) / 2;
        if(no[mid].sum >= res) right = mid;
        else left = mid + 1;
    }
    if(no[left].sum == res) return left;
    return -1;
}
int main()
{
    cin >> n;
    for(int i=0; i*i <= n; i++)
    {
        for(int j=i; j*j + i*i <=n; j++)
        {
            no[cnt++] = {i, j, i*i + j*j};
        }
    }
    sort(no,no+cnt,cmp);
    for(int i=0; i*i <= n; i++)
    {
        for(int j=i; i*i + j*j <=n; j++)
        {
            int res = n - i*i - j*j;
            int index = solve(res);
            if(index == -1) continue;
            cout << i << " " << j << " " << no[index].c << " " << no[index].d << endl;
            return 0;
        }
    }
    return 0;
}

枚举 + 哈希

#include <cstdio>
#include <iostream>
#include <cmath>
using namespace std;
const int N = 5e6 + 9;
int vis[N];
int map[N];
int n;
int main()
{
    cin >> n;
    for(int i=0; i*i<=n; i++)
    {
        for(int j=i; j*j +i*i <= n; j++)
        {
            int sum = i*i + j*j;
            if(!vis[sum])
            {
                map[sum] = i;
                vis[sum] = 1;
            }
        }
    }
    for(int i=0; i*i<=n; i++)
    {
        for(int j=i; j*j + i*i <= n; j++)
        {
            int sum = n - i*i - j*j;
            if(vis[sum])
            {
                int k = map[sum];
                int l = sqrt(n - i*i - j*j - k*k);
                cout << i << " " << j << " " << k << " " << l << endl;
                return 0;
            }
        }
    }
    return 0;
}

本文链接http://www.dzjqx.cn/news/show-617573.html