AtCoderのABC    前のABCの問題へ

ABC459-E Select from Subtrees


問題へのリンク


C#のソース

using System;
using System.Collections.Generic;
using System.Linq;

class Program
{
    static string InputPattern = "InputX";

    // Returns the input lines: one of the embedded samples when InputPattern
    // selects it, otherwise every line read from standard input.
    static List<string> GetInputList()
    {
        var WillReturn = new List<string>();

        if (InputPattern == "Input1") {
            WillReturn.Add("5");
            WillReturn.Add("1 1 3 3");
            WillReturn.Add("1 2 1 2 3");
            WillReturn.Add("1 1 3 1 1");
            //144
        }
        else if (InputPattern == "Input2") {
            WillReturn.Add("2");
            WillReturn.Add("1");
            WillReturn.Add("1 1");
            WillReturn.Add("2 1");
            //0
        }
        else if (InputPattern == "Input3") {
            WillReturn.Add("3");
            WillReturn.Add("3 1");
            WillReturn.Add("1000000000 1 1");
            WillReturn.Add("1 1 1");
            //1755647
        }
        else {
            string wkStr;
            while ((wkStr = Console.ReadLine()) != null) WillReturn.Add(wkStr);
        }
        return WillReturn;
    }

    // Parses a space-separated line of integers. A blank line yields an empty
    // array; repeated, leading or trailing spaces are tolerated.
    static long[] GetSplitArr(string pStr)
    {
        return pStr.Split(new[] { ' ' }, StringSplitOptions.RemoveEmptyEntries)
                   .Select(pX => long.Parse(pX))
                   .ToArray();
    }

    const long Hou = 998244353;

    static void Main()
    {
        Console.WriteLine(Solve(GetInputList()));
    }

    // Computes the answer from the raw input lines:
    //   line 0: N
    //   line 1: P_2 .. P_N (parent of each node 2..N)
    //   line 2: C_1 .. C_N
    //   line 3: D_1 .. D_N
    // For every node v reachable from node 1, with
    //   Rest(v) = (sum of C in v's subtree) - (sum of D in v's subtree excluding v),
    // the answer is the product of C(Rest(v), D_v), taken mod Hou.
    internal static long Solve(List<string> pInputList)
    {
        long N = long.Parse(pInputList[0]);
        long[] PArr = GetSplitArr(pInputList[1]);
        long[] CArr = GetSplitArr(pInputList[2]);
        long[] DArr = GetSplitArr(pInputList[3]);

        // Undirected adjacency list; PArr[I] is the parent of node I + 2.
        var ToNodeListDict = new Dictionary<long, List<long>>();
        for (long I = 0; I < PArr.LongLength; I++) {
            long Node = I + 2;
            long Parent = PArr[I];
            AddEdge(ToNodeListDict, Node, Parent);
            AddEdge(ToNodeListDict, Parent, Node);
        }

        // Deepest nodes first, so each child is finished before its parent
        // receives the child's subtree totals.
        List<JyoutaiDef> DFSResult = ExecDFS(1, ToNodeListDict);
        DFSResult = DFSResult.OrderByDescending(pX => pX.Level).ToList();

        // Index = node number (index 0 unused).
        var CSum = new long[N + 1]; // candies in the subtree
        var DSum = new long[N + 1]; // candies taken within the subtree

        // Tree DP that pushes subtree totals up to the parent.
        foreach (JyoutaiDef EachJyoutai in DFSResult) {
            long CurrNode = EachJyoutai.Node;
            long ParentNode = EachJyoutai.ParentNode;

            CSum[CurrNode] += CArr[CurrNode - 1];
            DSum[CurrNode] += DArr[CurrNode - 1];
            if (ParentNode != -1) {
                CSum[ParentNode] += CSum[CurrNode];
                DSum[ParentNode] += DSum[CurrNode];
            }
        }

        long Answer = 1;
        foreach (JyoutaiDef EachJyoutai in DFSResult) {
            long CurrNode = EachJyoutai.Node;
            long Take = DArr[CurrNode - 1];

            // Candies still available after all strict descendants took theirs.
            long Rest = CSum[CurrNode] - (DSum[CurrNode] - Take);

            Answer = Answer * DeriveChoose(Rest, Take) % Hou;
        }
        return Answer;
    }

    // Appends pTo to pFrom's neighbour list, creating the list on first use.
    static void AddEdge(Dictionary<long, List<long>> pDict, long pFrom, long pTo)
    {
        List<long> ToList;
        if (pDict.TryGetValue(pFrom, out ToList) == false) {
            ToList = new List<long>();
            pDict[pFrom] = ToList;
        }
        ToList.Add(pTo);
    }

    struct JyoutaiDef
    {
        internal long Node;
        internal long ParentNode; // -1 for the root
        internal long Level;      // root is 1
    }

    // Iterative stack-based traversal from pRootNode over pToNodeListDict.
    // Returns every reached node once, with its parent and depth.
    static List<JyoutaiDef> ExecDFS(long pRootNode, Dictionary<long, List<long>> pToNodeListDict)
    {
        var WillReturn = new List<JyoutaiDef>();

        var Stk = new Stack<JyoutaiDef>();
        JyoutaiDef WillPush;
        WillPush.Node = pRootNode;
        WillPush.ParentNode = -1;
        WillPush.Level = 1;
        Stk.Push(WillPush);

        var VisitedSet = new HashSet<long>();
        VisitedSet.Add(pRootNode);

        while (Stk.Count > 0) {
            JyoutaiDef Popped = Stk.Pop();
            WillReturn.Add(Popped);

            List<long> NextNodes;
            if (pToNodeListDict.TryGetValue(Popped.Node, out NextNodes) == false) continue;

            foreach (long EachToNode in NextNodes) {
                if (VisitedSet.Add(EachToNode)) {
                    WillPush.Node = EachToNode;
                    WillPush.ParentNode = Popped.Node;
                    WillPush.Level = Popped.Level + 1;
                    Stk.Push(WillPush);
                }
            }
        }
        return WillReturn;
    }

    // Returns nCr mod Hou; 0 when r < 0 or n < r.
    // Numerator and denominator are accumulated separately, and only a single
    // modular inverse is taken at the end (instead of one inverse per factor).
    internal static long DeriveChoose(long pN, long pR)
    {
        if (pR < 0 || pN < pR) return 0;

        pR = Math.Min(pR, pN - pR);

        long Numer = 1;
        long Denom = 1;
        for (long I = 1; I <= pR; I++) {
            Numer = Numer * ((pN - pR + I) % Hou) % Hou;
            Denom = Denom * (I % Hou) % Hou;
        }
        return Numer * DeriveGyakugen(Denom) % Hou;
    }

    // Modular inverse of pLong mod Hou via Fermat's little theorem (Hou is prime).
    static long DeriveGyakugen(long pLong)
    {
        return DeriveBekijyou(pLong, Hou - 2, Hou);
    }

    // Binary exponentiation: returns (pN ^ pP) mod pM.
    static long DeriveBekijyou(long pN, long pP, long pM)
    {
        long CurrJyousuu = pN % pM;
        long CurrShisuu = 1;
        long WillReturn = 1;

        while (true) {
            // Multiply in the current power when this bit of the exponent is set.
            if ((pP & CurrShisuu) > 0) {
                WillReturn = (WillReturn * CurrJyousuu) % pM;
            }

            CurrShisuu *= 2;
            if (CurrShisuu > pP) return WillReturn;
            CurrJyousuu = (CurrJyousuu * CurrJyousuu) % pM;
        }
    }
}


解説

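木DPで解いています。各頂点 v について、部分木内のアメの総数から v の子孫（v 自身を除く）で取るアメの総数を引いた残りを Rest(v) とし、そこから D_v 個を選ぶ方法の数 C(Rest(v), D_v) を求めます。全頂点についての積を 998244353 で割った余りが答えです。部分木の合計は、DFS の深さの降順（葉から根の順）に親へ配ることで集計しています。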