ATCoder 116 D (思维+贪心+栈)
D - Various Sushi
Time Limit: 2 sec / Memory Limit: 1024 MB
Score : 400 points
Problem Statement
There are
is the number of different kinds of toppings of the pieces you eat.
You want to have as much satisfaction as possible. Find this maximum satisfaction.
Constraints
- 1≤K≤N≤105
- 1≤ti≤N
- 1≤di≤109
- All values in input are integers.
Input
Input is given from Standard Input in the following format:
K d1 d2 . . . dN
Output
Print the maximum satisfaction that you can obtain.
Sample Input 1 Copy
5 3 1 9 1 7 2 6 2 5 3 1
Sample Output 1 Copy
26
If you eat Sushi 3:
- The base total deliciousness is 9+7+6=22.
- The variety bonus is 2∗2=4.
Thus, your satisfaction will be 26, which is optimal.
Sample Input 2 Copy
7 4 1 1 2 1 3 1 4 6 4 5 4 5 4 5
Sample Output 2 Copy
25
It is optimal to eat Sushi 4.
Sample Input 3 Copy
6 5 5 1000000000 2 990000000 3 980000000 6 970000000 6 960000000 4 950000000
Sample Output 3 Copy
4900000016
Note that the output may not fit into a 32-bit integer type.
题意:
给定N个结构体,每一个结构体有两个信息,分别是nub 和 val,让你从中选出K个结构体,
使之 nub 的类型数的平方+sum{val i } 最大。
思路:对于给定的结构体,按照val进行从大到小排序。
建立一个栈,用来存储贡献值小的数(先入栈的贡献值大)
预处理前k个结构体,若该结构体的nub出现过,则入栈。用sum1存储val的和,sum2存储nub种类数。总贡献就是sum1 + sum2的平方
定义一个整型变量maxn用来维护最大值
k+1到n,如果此结构体 i 的nub没有出现过,则取出栈顶元素x,用 i 来替换 x的信息,得出一个总贡献(不一定是最优解),用maxn更新一下。
最后输出maxn就ok了。注意当栈为空的时候就可以跳出循环了,因为后面的数贡献值肯定不会大于前面的)
#include <iostream> #include <cmath> #include <cstdio> #include <cstring> #include <string> #include <map> #include <iomanip> #include <algorithm> #include <queue> #include <stack> #include <set> #include <vector> //const int maxn = 1e5+5; #define ll long long #define inf 0x3f3f3f3f #define FOR(i,a,b) for( int i = a;i <= b;++i) #define bug cout<<"--------------"<<endl ll gcd(ll a,ll b){return b?gcd(b,a%b):a;} ll lcm(ll a,ll b){return a/gcd(a,b)*b;} //const int maxn = 110000; using namespace std; int n,k; int vis[1000010]; struct node { int val,nub; }a[1000010]; bool cmp(node x,node y) { return x.val > y.val; } int main() { //freopen("C:\ACM\input.txt","r",stdin); cin>>n>>k; for(int i = 1;i <= n; ++i) cin>>a[i].nub>>a[i].val; sort(a+1 ,a+1+n,cmp); stack<int>sta; ll sum1 = 0; ll sum2 = 0; //for(int i = 1;i <= n; ++i) cout<<a[i].nub<<" "<<a[i].val<<endl; for(int i = 1;i <= k; ++i) { if(vis[a[i].nub] == 1) sta.push(i); else { vis[a[i].nub] = 1; sum2++; } sum1 += a[i].val; } //cout<<sum1 + sum2 * sum2<<endl; ll maxn = sum1 + sum2 * sum2; for(int i = k+1;i <= n; ++i) { if(sta.size() == 0) break; if(vis[a[i].nub] == 1) continue; int x = sta.top(); sta.pop(); sum1 = sum1 - a[x].val + a[i].val; sum2++; vis[a[i].nub] = 1; maxn = max(maxn,sum1 + sum2 * sum2); } cout<<maxn<<endl; }