// Edit code

#include <bits/stdc++.h>
using namespace std;
#define INF 0xffff

// Read `n` undirected weighted edges "a b len" (1-based vertex ids) from stdin
// into the adjacency matrix `node`, keeping the lightest edge for each pair.
//   node - square matrix pre-filled with INF ("no edge"); updated symmetrically
//   n    - number of edge lines to read (the edge count, NOT the vertex count)
// Self-loops are dropped because they can never be part of a spanning tree.
// Reading stops early if input ends or is malformed, and edges whose endpoints
// fall outside [1, node.size()] are ignored instead of being written out of
// bounds.
void scan(vector<vector<int>> &node, int n)
{
    const int vertices = static_cast<int>(node.size());
    int a, b, len;
    while (n--)
    {
        // On a short/garbled read a, b, len would hold garbage; stop instead.
        if (scanf("%d %d %d", &a, &b, &len) != 3)
            break;
        if (a < 1 || a > vertices || b < 1 || b > vertices)
            continue;
        if (a == b)
            continue;

        int &cur = node[a-1][b-1];
        // A parallel edge that is already at least as light wins.
        if (cur != INF && cur <= len)
            continue;
        cur = len;
        node[b-1][a-1] = len;  // keep the matrix symmetric (undirected graph)
    }
}

// Refresh Prim's candidate costs after a vertex has joined the tree.
//   side     - adjacency row (edge weights) of the vertex that just joined
//   low_cost - per-vertex cheapest known edge into the tree; updated in place
//   visited  - 1 = vertex not yet in the tree, 0 = already in the tree
// Vertices already in the tree are parked at INF so the caller's min_element
// never selects them again; every other vertex keeps the cheaper of its old
// cost and its edge to the newly added vertex.
void change(vector<int> &side, vector<int> &low_cost, vector<int> &visited)
{
    const size_t count = side.size();
    for (size_t v = 0; v < count; ++v)
    {
        int &best = low_cost[v];
        best = visited[v] ? min(best, side[v]) : INF;
    }
}

// Dense (adjacency-matrix) Prim's algorithm grown from vertex 0, O(n^2).
//   node    - n x n matrix of edge weights; INF marks "no edge"
//   visited - one flag per vertex, expected to be 1 on entry; a vertex's flag
//             is set to 0 when it joins the tree (modified in place)
//   ans     - accumulator; the spanning-tree weight is added to its value
// Prints the accumulated weight on success, or "orz" as soon as no remaining
// vertex is reachable from the tree (i.e. the graph is disconnected).
void prim(vector<vector<int>> &node, vector<int> &visited, int &ans)
{
    const size_t n = visited.size();
    // An empty graph has an empty spanning tree. Without this guard `n - 1`
    // wraps around (size_t) and the loop indexes past the end of the vectors.
    if (n == 0) {
        cout << ans;
        return;
    }

    vector<int> low_cost(n, INF);
    size_t s = 0;  // vertex most recently added to the tree

    // Each pass adds one vertex; n - 1 passes connect all n vertices.
    for (size_t j = 0; j + 1 < n; j++)
    {
        visited[s] = 0;
        // change() only reads the row, so pass it directly instead of copying
        // the whole row on every pass.
        change(node[s], low_cost, visited);

        auto it = min_element(low_cost.begin(), low_cost.end());
        if (*it == INF) {  // nothing outside the tree is reachable
            cout << "orz";
            return;
        }

        ans += *it;
        s = static_cast<size_t>(distance(low_cost.begin(), it));
    }
    cout << ans;
}

// Entry point: reads "n m" (vertex count, edge count), then m edges
// "a b len", and prints the minimum spanning tree weight, or "orz" if the
// graph is disconnected.
int main()
{
    int n, m;
    // Refuse a missing or negative header rather than sizing the matrix with
    // an uninitialized/negative count (or looping forever in scan on m < 0).
    if (!(cin >> n >> m) || n < 0 || m < 0)
        return 1;

    vector<vector<int>> node(n, vector<int>(n, INF));
    scan(node, m);

    int ans = 0;
    vector<int> visited(n, 1);  // 1 = vertex not yet in the spanning tree

    prim(node, visited, ans);

    return 0;
}