classdef test_aggregate_countries_WIOD < matlab.unittest.TestCase
%TEST_AGGREGATE_COUNTRIES_WIOD Unit tests for aggregate_countries_WIOD function
%
%   Run with: runtests('test_aggregate_countries_WIOD')
%
%   The tests build small synthetic WIOD-like inputs (see
%   createMinimalWIODData) instead of reading real WIOD files. The synthetic
%   values come from a private, seeded random stream, so every run sees the
%   same numbers and the global random number generator is left untouched.

    properties
        % Absolute path of the toolkit's Functions folder, which is put on
        % the MATLAB path for the lifetime of this test class.
        FuncPath
    end

    methods(TestClassSetup)
        function setupOnce(testCase)
            % Add the Functions folder to the path. PathFixture restores the
            % path that existed before setup once the class has finished, so
            % a folder that was already on the user's path is not removed
            % (unlike a plain addpath/rmpath pair).
            testCase.FuncPath = fullfile(fileparts(fileparts(mfilename('fullpath'))), ...
                'EMuSe_Calibration_Toolkit', 'Functions');
            testCase.applyFixture( ...
                matlab.unittest.fixtures.PathFixture(testCase.FuncPath));
        end
    end

    methods(Static)
        function Data = createMinimalWIODData(countries, NACE, seed)
            %CREATEMINIMALWIODDATA Create minimal synthetic WIOD Data structure
            %
            %   Data = createMinimalWIODData(countries, NACE) builds a struct
            %   that mimics the output of extract_WIOD_sectors as used by
            %   these tests. For every ordered country pair (ci, cj):
            %     Data.(ci).NA.(cj) - numel(NACE)-by-5 table with variables
            %                         Con_HH, Con_NP, CON_Gov, GFCF, Inv
            %     Data.(ci).IO.(cj) - numel(NACE)-by-numel(NACE) table whose
            %                         variable names are the NACE codes
            %   Both tables use the NACE codes as row names. Entries are
            %   uniformly distributed in [0, 1000).
            %
            %   Data = createMinimalWIODData(countries, NACE, seed) draws the
            %   values from a stream seeded with SEED (default 0).

            if nargin < 3
                seed = 0;
            end
            % Private stream: reproducible fixtures without changing the
            % global RNG state that other tests or the user may rely on.
            stream = RandStream('mt19937ar', 'Seed', seed);

            numSectors = numel(NACE);
            numCountries = numel(countries);
            sectorNames = cellstr(NACE);
            naVariables = {'Con_HH', 'Con_NP', 'CON_Gov', 'GFCF', 'Inv'};

            Data = struct();

            for i = 1:numCountries
                Data.(countries{i}) = struct();
                Data.(countries{i}).NA = struct();
                Data.(countries{i}).IO = struct();

                for j = 1:numCountries
                    % NA: sectors x 5 final-demand variables
                    naData = rand(stream, numSectors, numel(naVariables)) * 1000;
                    Data.(countries{i}).NA.(countries{j}) = array2table(naData, ...
                        'VariableNames', naVariables, ...
                        'RowNames', sectorNames);

                    % IO: sectors x sectors
                    ioData = rand(stream, numSectors, numSectors) * 1000;
                    Data.(countries{i}).IO.(countries{j}) = array2table(ioData, ...
                        'VariableNames', sectorNames, ...
                        'RowNames', sectorNames);
                end
            end
        end
    end

    methods(Static, Access = private)
        function [Data_final, ISO_regions] = aggregateOneCountryPerRegion(NACE)
            %AGGREGATEONECOUNTRYPERREGION Run the standard 4-country scenario
            %
            %   Builds synthetic data for DEU, FRA, ITA and ESP, assigns one
            %   country to each of the four regions Reg_a..Reg_d, and returns
            %   the result of aggregate_countries_WIOD together with the
            %   region names that were used.
            countries = {'DEU', 'FRA', 'ITA', 'ESP'};
            ISO_regions = ["Reg_a", "Reg_b", "Reg_c", "Reg_d"];

            Data = test_aggregate_countries_WIOD.createMinimalWIODData(countries, NACE);

            Data_final = aggregate_countries_WIOD(countries, {'DEU'}, {'FRA'}, ...
                {'ITA'}, {'ESP'}, ISO_regions, NACE, Data);
        end
    end

    methods(Test)
        %% Input Validation Tests
        function testEmptyISOallThrowsError(testCase)
            testCase.verifyError(...
                @() aggregate_countries_WIOD({}, {'A'}, {'B'}, {'C'}, {'D'}, ...
                    ["Reg_a", "Reg_b", "Reg_c", "Reg_d"], "A", struct()), ...
                'aggregate_countries_WIOD:InvalidISOall');
        end

        function testInvalidISOallTypeThrowsError(testCase)
            testCase.verifyError(...
                @() aggregate_countries_WIOD('not a cell', {'A'}, {'B'}, {'C'}, {'D'}, ...
                    ["Reg_a", "Reg_b", "Reg_c", "Reg_d"], "A", struct()), ...
                'aggregate_countries_WIOD:InvalidISOall');
        end

        function testInvalidISOaTypeThrowsError(testCase)
            testCase.verifyError(...
                @() aggregate_countries_WIOD({'DEU'}, 'not a cell', {'B'}, {'C'}, {'D'}, ...
                    ["Reg_a", "Reg_b", "Reg_c", "Reg_d"], "A", struct()), ...
                'aggregate_countries_WIOD:InvalidISOa');
        end

        function testEmptyRegionsThrowsError(testCase)
            testCase.verifyError(...
                @() aggregate_countries_WIOD({'DEU'}, {'DEU'}, {}, {}, {}, ...
                    [], "A", struct()), ...
                'aggregate_countries_WIOD:EmptyRegions');
        end

        function testEmptyNACEThrowsError(testCase)
            % Only NACE is invalid; every other argument is well-formed so the
            % expected error can come only from the NACE check, regardless of
            % the order in which the function validates its inputs.
            testCase.verifyError(...
                @() aggregate_countries_WIOD({'DEU', 'FRA', 'ITA', 'ESP'}, ...
                    {'DEU'}, {'FRA'}, {'ITA'}, {'ESP'}, ...
                    ["Reg_a", "Reg_b", "Reg_c", "Reg_d"], [], struct()), ...
                'aggregate_countries_WIOD:EmptyNACE');
        end

        function testInvalidDataTypeThrowsError(testCase)
            % Only Data is invalid (a char vector instead of a struct).
            testCase.verifyError(...
                @() aggregate_countries_WIOD({'DEU', 'FRA', 'ITA', 'ESP'}, ...
                    {'DEU'}, {'FRA'}, {'ITA'}, {'ESP'}, ...
                    ["Reg_a", "Reg_b", "Reg_c", "Reg_d"], "A", 'not a struct'), ...
                'aggregate_countries_WIOD:InvalidData');
        end

        %% Output Structure Tests
        function testOutputIsStruct(testCase)
            Data_final = test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(["A", "B"]);

            testCase.verifyClass(Data_final, 'struct');
        end

        function testOutputHasRegionFields(testCase)
            [Data_final, ISO_regions] = ...
                test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(["A", "B"]);

            % Verify region fields exist
            for r = 1:numel(ISO_regions)
                testCase.verifyTrue(isfield(Data_final, ISO_regions(r)), ...
                    sprintf('Missing region field: %s', ISO_regions(r)));
            end
        end

        function testOutputRegionHasNAandIO(testCase)
            [Data_final, ISO_regions] = ...
                test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(["A", "B"]);

            % Each region should have NA and IO subfields
            for r = 1:numel(ISO_regions)
                region = ISO_regions(r);
                testCase.verifyTrue(isfield(Data_final.(region), 'NA'), ...
                    sprintf('Region %s is missing NA', region));
                testCase.verifyTrue(isfield(Data_final.(region), 'IO'), ...
                    sprintf('Region %s is missing IO', region));
            end
        end

        function testOutputHasPsiShares(testCase)
            Data_final = test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(["A", "B"]);

            % Verify Psi shares exist
            testCase.verifyTrue(isfield(Data_final.Reg_a.NA, 'Psi_C_I_G'), ...
                'Missing Psi_C_I_G in NA');
            testCase.verifyTrue(isfield(Data_final.Reg_a.IO, 'Psi_H'), ...
                'Missing Psi_H in IO');
        end

        function testOutputHasBiases(testCase)
            Data_final = test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(["A", "B"]);

            % Verify biases exist
            testCase.verifyTrue(isfield(Data_final.Reg_a.NA, 'Biases_C_I'), ...
                'Missing Biases_C_I in NA');
            testCase.verifyTrue(isfield(Data_final.Reg_a.IO, 'Biases_hhh'), ...
                'Missing Biases_hhh in IO');
        end

        %% Data Integrity Tests
        function testPsiSharesSumToOne(testCase)
            Data_final = test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(["A", "B"]);

            Psi_C_I_G = table2array(Data_final.Reg_a.NA.Psi_C_I_G);

            % Psi_C (column 1) should sum to 1
            testCase.verifyEqual(sum(Psi_C_I_G(:, 1)), 1, 'AbsTol', 1e-10, ...
                'Consumption shares (Psi_C) should sum to 1');

            % Psi_I (column 4) should sum to 1
            testCase.verifyEqual(sum(Psi_C_I_G(:, 4)), 1, 'AbsTol', 1e-10, ...
                'Investment shares (Psi_I) should sum to 1');
        end

        function testPsiHSharesSumToOne(testCase)
            NACE = ["A", "B"];
            Data_final = test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(NACE);

            % Each column of Psi_H should sum to 1
            Psi_H = table2array(Data_final.Reg_a.IO.Psi_H);
            colSums = sum(Psi_H, 1);
            testCase.verifyEqual(colSums, ones(1, numel(NACE)), 'AbsTol', 1e-10, ...
                'Intermediate input shares (Psi_H) should sum to 1 for each sector');
        end

        function testBiasesSumToOne(testCase)
            NACE = ["A", "B"];
            [Data_final, ISO_regions] = ...
                test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(NACE);

            % Sum biases across all regions for each sector. Starting from the
            % scalar 0 lets the first addition fix the size, so the test does
            % not hard-code the width of the Biases_C_I tables.
            bias_sum = 0;
            for r = 1:numel(ISO_regions)
                bias_sum = bias_sum + table2array(Data_final.Reg_a.NA.Biases_C_I.(ISO_regions(r)));
            end

            % Each row should sum to 1 for consumption bias (column 1)
            testCase.verifyEqual(bias_sum(:, 1), ones(numel(NACE), 1), 'AbsTol', 1e-10, ...
                'Home biases should sum to 1 across regions');
        end

        %% Edge Cases
        function testSingleCountryPerRegion(testCase)
            % One country per region AND a single NACE sector, so every NA/IO
            % table has exactly one row (and IO tables are 1-by-1).
            % Should work without error.
            Data_final = test_aggregate_countries_WIOD.aggregateOneCountryPerRegion("A");

            testCase.verifyTrue(isfield(Data_final, 'Reg_a'));
        end

        function testMultipleCountriesInRegion(testCase)
            countries = {'DEU', 'FRA', 'ITA', 'ESP', 'NLD', 'BEL'};
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a", "Reg_b", "Reg_c", "Reg_d"];

            Data = test_aggregate_countries_WIOD.createMinimalWIODData(countries, NACE);

            % Region A has 3 countries, others have 1 each
            Data_final = aggregate_countries_WIOD(countries, ...
                {'DEU', 'FRA', 'ITA'}, {'ESP'}, {'NLD'}, {'BEL'}, ...
                ISO_regions, NACE, Data);

            testCase.verifyTrue(isfield(Data_final, 'Reg_a'));
            testCase.verifyTrue(isfield(Data_final.Reg_a.NA, 'Psi_C_I_G'));
        end

        function testAllRegionsPopulated(testCase)
            % All four regions must have at least one country for valid Psi calculation
            % (Empty regions cause division by zero in share computation)
            [Data_final, ISO_regions] = ...
                test_aggregate_countries_WIOD.aggregateOneCountryPerRegion(["A", "B"]);

            % All regions should exist
            for r = 1:numel(ISO_regions)
                testCase.verifyTrue(isfield(Data_final, ISO_regions(r)), ...
                    sprintf('Missing region field: %s', ISO_regions(r)));
            end
        end
    end
end
