classdef test_compute_wiod_means < matlab.unittest.TestCase
%TEST_COMPUTE_WIOD_MEANS Unit tests for compute_wiod_means function
%
%   Run with: runtests('test_compute_wiod_means')
%
%   The EMuSe_Calibration_Toolkit/Functions folder is placed on the MATLAB
%   path through a PathFixture for the duration of the class, so the
%   caller's path is restored afterwards. Each test method runs with a
%   fixed random seed so the rand()-based synthetic fixtures are
%   reproducible; the global generator state is restored after each test.

    properties
        FuncPath   % Absolute path of the folder expected to hold compute_wiod_means
    end

    methods(TestClassSetup)
        function setupOnce(testCase)
            % Resolve <parent of this test folder>/EMuSe_Calibration_Toolkit/Functions.
            testCase.FuncPath = fullfile(fileparts(fileparts(mfilename('fullpath'))), ...
                'EMuSe_Calibration_Toolkit', 'Functions');
            % PathFixture restores the previous path on teardown. A bare
            % addpath/rmpath pair would also remove the folder if it was
            % already on the path before the tests ran.
            testCase.applyFixture(matlab.unittest.fixtures.PathFixture(testCase.FuncPath));
        end
    end

    methods(TestMethodSetup)
        function seedRandomNumbers(testCase)
            % Make createMinimalCalibration's rand() output deterministic
            % and put the caller's generator state back when the test ends.
            originalRng = rng;
            testCase.addTeardown(@rng, originalRng);
            rng(0, 'twister');
        end
    end

    methods(Static)
        function Calibration = createMinimalCalibration(NACE, ISO_regions, year_codes)
            %CREATEMINIMALCALIBRATION Create minimal synthetic Calibration structure
            %
            % Creates a struct mimicking the output of aggregate_countries_WIOD
            % with WIOD data for multiple years. For every year and region:
            %   NA.Psi_C_I_G          numSectors x 5 table (Psi_C and Psi_I columns sum to 1)
            %   IO.Psi_H              numSectors x numSectors table (columns sum to 1)
            %   NA.Biases_C_I.<reg>   numSectors x 5 table, one per partner region
            %   IO.Biases_hhh.<reg>   numSectors x numSectors table, one per partner region
            %
            % Inputs:
            %   NACE        - sector codes (string array or cellstr); used as row names
            %   ISO_regions - region identifiers (string array or cellstr); used as field names
            %   year_codes  - cell array of year field names, e.g. {'y05','y06'}

            numSectors = length(NACE);
            numRegions = length(ISO_regions);
            numYears = length(year_codes);

            Calibration = struct();
            Calibration.WIOD = struct();

            for y = 1:numYears
                Calibration.WIOD.(year_codes{y}) = struct();

                for r = 1:numRegions
                    % NA structure with Psi_C_I_G
                    psiData = rand(numSectors, 5);
                    psiData(:, 1) = psiData(:, 1) / sum(psiData(:, 1));  % Normalize Psi_C
                    psiData(:, 4) = psiData(:, 4) / sum(psiData(:, 4));  % Normalize Psi_I

                    Calibration.WIOD.(year_codes{y}).(ISO_regions{r}).NA.Psi_C_I_G = ...
                        array2table(psiData, ...
                        'VariableNames', {'Psi_C', 'Psi_Con_NP', 'Psi_G', 'Psi_I', 'Psi_Inventories'}, ...
                        'RowNames', cellstr(NACE));

                    % IO structure with Psi_H
                    psiHData = rand(numSectors, numSectors);
                    psiHData = psiHData ./ sum(psiHData, 1);  % Normalize columns
                    Calibration.WIOD.(year_codes{y}).(ISO_regions{r}).IO.Psi_H = ...
                        array2table(psiHData, ...
                        'VariableNames', cellstr(NACE), ...
                        'RowNames', cellstr(NACE));

                    % Biases structures (one table per partner region r2)
                    Calibration.WIOD.(year_codes{y}).(ISO_regions{r}).NA.Biases_C_I = struct();
                    Calibration.WIOD.(year_codes{y}).(ISO_regions{r}).IO.Biases_hhh = struct();

                    for r2 = 1:numRegions
                        biasData = rand(numSectors, 5);
                        Calibration.WIOD.(year_codes{y}).(ISO_regions{r}).NA.Biases_C_I.(ISO_regions{r2}) = ...
                            array2table(biasData, ...
                            'VariableNames', {'HB_C', 'HB_NP', 'HB_G', 'HB_I', 'HB_Inventories'}, ...
                            'RowNames', cellstr(NACE));

                        biasHData = rand(numSectors, numSectors);
                        Calibration.WIOD.(year_codes{y}).(ISO_regions{r}).IO.Biases_hhh.(ISO_regions{r2}) = ...
                            array2table(biasHData, ...
                            'VariableNames', cellstr(NACE), ...
                            'RowNames', cellstr(NACE));
                    end
                end
            end
        end

        function Calibration = createKnownTwoYearCalibration()
            %CREATEKNOWNTWOYEARCALIBRATION Hand-built fixture with known values
            %
            % One sector ('A'), one region ('Reg_a'), two years (y05, y06):
            %   NA.Psi_C_I_G          y05 = [1 2 3 4 5], y06 = [3 4 5 6 7]
            %   IO.Psi_H              y05 = 0.5,         y06 = 0.7
            %   NA.Biases_C_I.Reg_a   [1 1 1 1 1] in both years
            %   IO.Biases_hhh.Reg_a   0.5 in both years
            % Fields are assigned in the same order as the original inline
            % fixtures so the resulting struct layout is unchanged.

            psiNames  = {'Psi_C', 'Psi_Con_NP', 'Psi_G', 'Psi_I', 'Psi_Inventories'};
            biasNames = {'HB_C', 'HB_NP', 'HB_G', 'HB_I', 'HB_Inventories'};

            Calibration = struct();
            Calibration.WIOD.y05.Reg_a.NA.Psi_C_I_G = array2table([1 2 3 4 5], ...
                'VariableNames', psiNames, 'RowNames', {'A'});
            Calibration.WIOD.y06.Reg_a.NA.Psi_C_I_G = array2table([3 4 5 6 7], ...
                'VariableNames', psiNames, 'RowNames', {'A'});

            Calibration.WIOD.y05.Reg_a.IO.Psi_H = array2table([0.5], 'VariableNames', {'A'}, 'RowNames', {'A'});
            Calibration.WIOD.y06.Reg_a.IO.Psi_H = array2table([0.7], 'VariableNames', {'A'}, 'RowNames', {'A'});

            Calibration.WIOD.y05.Reg_a.NA.Biases_C_I.Reg_a = array2table([1 1 1 1 1], ...
                'VariableNames', biasNames, 'RowNames', {'A'});
            Calibration.WIOD.y06.Reg_a.NA.Biases_C_I.Reg_a = array2table([1 1 1 1 1], ...
                'VariableNames', biasNames, 'RowNames', {'A'});

            Calibration.WIOD.y05.Reg_a.IO.Biases_hhh.Reg_a = array2table([0.5], 'VariableNames', {'A'}, 'RowNames', {'A'});
            Calibration.WIOD.y06.Reg_a.IO.Biases_hhh.Reg_a = array2table([0.5], 'VariableNames', {'A'}, 'RowNames', {'A'});
        end
    end

    methods(Test)
        %% Input Validation Tests
        function testInvalidCalibrationTypeThrowsError(testCase)
            testCase.verifyError(...
                @() compute_wiod_means('not a struct', ["A"], ["Reg_a"], {'y05'}), ...
                'compute_wiod_means:InvalidCalibration');
        end

        function testMissingWIODFieldThrowsError(testCase)
            Calibration = struct('SEA', struct());
            testCase.verifyError(...
                @() compute_wiod_means(Calibration, ["A"], ["Reg_a"], {'y05'}), ...
                'compute_wiod_means:MissingWIOD');
        end

        function testEmptyNACEThrowsError(testCase)
            Calibration = struct('WIOD', struct());
            testCase.verifyError(...
                @() compute_wiod_means(Calibration, [], ["Reg_a"], {'y05'}), ...
                'compute_wiod_means:EmptyNACE');
        end

        function testEmptyRegionsThrowsError(testCase)
            Calibration = struct('WIOD', struct());
            testCase.verifyError(...
                @() compute_wiod_means(Calibration, ["A"], [], {'y05'}), ...
                'compute_wiod_means:EmptyRegions');
        end

        function testEmptyYearCodesThrowsError(testCase)
            Calibration = struct('WIOD', struct());
            testCase.verifyError(...
                @() compute_wiod_means(Calibration, ["A"], ["Reg_a"], {}), ...
                'compute_wiod_means:EmptyYearCodes');
        end

        %% Output Structure Tests
        function testOutputHasSumField(testCase)
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a", "Reg_b"];
            year_codes = {'y05', 'y06'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);

            testCase.verifyTrue(isfield(result.WIOD, 'sum'), ...
                'Output should have WIOD.sum field');
        end

        function testOutputHasMeanField(testCase)
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a", "Reg_b"];
            year_codes = {'y05', 'y06'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);

            testCase.verifyTrue(isfield(result.WIOD, 'mean'), ...
                'Output should have WIOD.mean field');
        end

        function testOutputMeanHasRegionFields(testCase)
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a", "Reg_b"];
            year_codes = {'y05', 'y06'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);

            for r = 1:length(ISO_regions)
                testCase.verifyTrue(isfield(result.WIOD.mean, ISO_regions(r)), ...
                    sprintf('Missing region field in mean: %s', ISO_regions(r)));
            end
        end

        function testOutputMeanHasNAandIO(testCase)
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a"];
            year_codes = {'y05', 'y06'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);

            testCase.verifyTrue(isfield(result.WIOD.mean.Reg_a, 'NA'), ...
                'Missing NA field in mean.Reg_a');
            testCase.verifyTrue(isfield(result.WIOD.mean.Reg_a, 'IO'), ...
                'Missing IO field in mean.Reg_a');
        end

        function testOutputMeanHasPsiTables(testCase)
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a"];
            year_codes = {'y05', 'y06'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);

            testCase.verifyTrue(istable(result.WIOD.mean.Reg_a.NA.Psi_C_I_G), ...
                'Psi_C_I_G should be a table');
            testCase.verifyTrue(istable(result.WIOD.mean.Reg_a.IO.Psi_H), ...
                'Psi_H should be a table');
        end

        function testOutputMeanHasBiases(testCase)
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a", "Reg_b"];
            year_codes = {'y05', 'y06'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);

            testCase.verifyTrue(isfield(result.WIOD.mean.Reg_a.NA, 'Biases_C_I'), ...
                'Missing Biases_C_I in mean.Reg_a.NA');
            testCase.verifyTrue(isfield(result.WIOD.mean.Reg_a.IO, 'Biases_hhh'), ...
                'Missing Biases_hhh in mean.Reg_a.IO');
        end

        %% Correctness Tests
        function testMeanIsCorrectForSingleYear(testCase)
            % With a single year, mean should equal the original value
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a"];
            year_codes = {'y05'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            original_psi = table2array(Calibration.WIOD.y05.Reg_a.NA.Psi_C_I_G);

            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);
            mean_psi = table2array(result.WIOD.mean.Reg_a.NA.Psi_C_I_G);

            testCase.verifyEqual(mean_psi, original_psi, 'AbsTol', 1e-10, ...
                'Mean of single year should equal original value');
        end

        function testMeanIsAverageOfTwoYears(testCase)
            % With two years, mean should be average
            NACE = ["A"];
            ISO_regions = ["Reg_a"];
            year_codes = {'y05', 'y06'};

            Calibration = test_compute_wiod_means.createKnownTwoYearCalibration();

            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);
            mean_psi = table2array(result.WIOD.mean.Reg_a.NA.Psi_C_I_G);

            expected = [2 3 4 5 6];  % Average of [1 2 3 4 5] and [3 4 5 6 7]
            testCase.verifyEqual(mean_psi, expected, 'AbsTol', 1e-10, ...
                'Mean should be average of two years');
        end

        function testSumIsCorrect(testCase)
            % Test that sum field contains correct sum
            NACE = ["A"];
            ISO_regions = ["Reg_a"];
            year_codes = {'y05', 'y06'};

            Calibration = test_compute_wiod_means.createKnownTwoYearCalibration();

            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);
            % NOTE(review): unlike the mean tests (which use table2array), this
            % assumes WIOD.sum stores plain numeric arrays rather than tables —
            % confirm against compute_wiod_means.
            sum_psi = result.WIOD.sum.Reg_a.NA.Psi_C_I_G;

            expected = [4 6 8 10 12];  % Sum of [1 2 3 4 5] and [3 4 5 6 7]
            testCase.verifyEqual(sum_psi, expected, 'AbsTol', 1e-10, ...
                'Sum should be sum of two years');
        end

        %% Table Structure Tests
        function testOutputTableHasCorrectRowNames(testCase)
            NACE = ["A", "B", "C"];
            ISO_regions = ["Reg_a"];
            year_codes = {'y05'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);

            % Table RowNames are a column cellstr, hence the transpose.
            rowNames = result.WIOD.mean.Reg_a.NA.Psi_C_I_G.Properties.RowNames;
            testCase.verifyEqual(rowNames, cellstr(NACE)', ...
                'Row names should match NACE codes');
        end

        function testOutputTableHasCorrectVariableNames(testCase)
            NACE = ["A", "B"];
            ISO_regions = ["Reg_a"];
            year_codes = {'y05'};

            Calibration = test_compute_wiod_means.createMinimalCalibration(NACE, ISO_regions, year_codes);
            result = compute_wiod_means(Calibration, NACE, ISO_regions, year_codes);

            varNames = result.WIOD.mean.Reg_a.NA.Psi_C_I_G.Properties.VariableNames;
            expected = {'Psi_C', 'Psi_Con_NP', 'Psi_G', 'Psi_I', 'Psi_Inventories'};
            testCase.verifyEqual(varNames, expected, ...
                'Variable names should match expected Psi names');
        end
    end
end
