classdef test_extract_SEA_sectors < matlab.unittest.TestCase
%TEST_EXTRACT_SEA_SECTORS Unit tests for extract_SEA_sectors function
%
%   Covers input validation, output structure, sector aggregation,
%   country selection and edge cases, using small synthetic SEA tables
%   built by the static helper createMinimalSEA.
%
%   Run with: runtests('test_extract_SEA_sectors')

    properties
        % Absolute path of <repo>/EMuSe_Calibration_Toolkit/Functions,
        % the folder put on the MATLAB path during class setup.
        FuncPath
    end

    methods(TestClassSetup)
        function setupOnce(testCase)
            % Put the Functions folder on the MATLAB path for the lifetime
            % of this test class. PathFixture restores the exact previous
            % path when the class finishes, so a folder that was already on
            % the path before the run is NOT removed afterwards (a plain
            % addpath/rmpath pair would remove it).
            testCase.FuncPath = fullfile(fileparts(fileparts(mfilename('fullpath'))), ...
                'EMuSe_Calibration_Toolkit', 'Functions');
            testCase.applyFixture(matlab.unittest.fixtures.PathFixture(testCase.FuncPath));
        end
    end

    methods(Static)
        function SEA_table = createMinimalSEA(countries, sectors, years)
            %CREATEMINIMALSEA Create minimal synthetic SEA data for testing
            %
            %   SEA_table = createMinimalSEA(countries, sectors) builds a table
            %   with one row per (country, variable, sector) combination and
            %   one numeric column per year in 2000:2016.
            %
            %   SEA_table = createMinimalSEA(countries, sectors, years) uses the
            %   given years instead of 2000:2016.
            %
            %   Columns mimic the WIOD SEA layout:
            %       country, variable, description, code   (categorical)
            %       y00, y01, ..., y16                     (double)
            %   Year columns are named 'y' + last two digits of the year, so
            %   the supplied years must have distinct two-digit suffixes.
            %
            %   Data values are deterministic and strictly positive
            %   (row index * 1000 + year position). Any sum over rows is
            %   therefore non-zero, and results are reproducible. The global
            %   random number generator is not touched.
            %
            %   Inputs:
            %       countries - cell array of country codes, e.g. {'DEU'}
            %       sectors   - cell array of sector codes, e.g. {'A01', 'B'}
            %       years     - numeric vector of years (default 2000:2016)

            if nargin < 3
                years = 2000:2016;
            end

            vars = {'VA', 'GO', 'II', 'LAB', 'COMP', 'EMP', 'H_EMPE', 'EMPE', 'K'};
            numVars = numel(vars);
            numCountries = numel(countries);
            numSectors = numel(sectors);
            numYears = numel(years);

            totalRows = numCountries * numVars * numSectors;

            % Year column names, e.g. 2000 -> 'y00', 2016 -> 'y16'
            yearCols = cell(1, numYears);
            for y = 1:numYears
                yearCols{y} = sprintf('y%02d', mod(years(y), 100));
            end

            % Pre-allocate
            country_col = cell(totalRows, 1);
            variable_col = cell(totalRows, 1);
            description_col = cell(totalRows, 1);
            code_col = cell(totalRows, 1);
            data_cols = zeros(totalRows, numYears);

            % Row order: country (outer), variable, sector (inner)
            rowIdx = 1;
            for c = 1:numCountries
                for v = 1:numVars
                    for s = 1:numSectors
                        country_col{rowIdx} = countries{c};
                        variable_col{rowIdx} = vars{v};
                        description_col{rowIdx} = sprintf('%s - %s', sectors{s}, vars{v});
                        code_col{rowIdx} = sectors{s};
                        % Distinct, strictly positive, reproducible values
                        data_cols(rowIdx, :) = rowIdx * 1000 + (1:numYears);
                        rowIdx = rowIdx + 1;
                    end
                end
            end

            % Create table
            SEA_table = table(categorical(country_col), categorical(variable_col), ...
                categorical(description_col), categorical(code_col), ...
                'VariableNames', {'country', 'variable', 'description', 'code'});

            % Add year columns
            for y = 1:numYears
                SEA_table.(yearCols{y}) = data_cols(:, y);
            end
        end
    end

    methods(Test)
        %% Input Validation Tests
        function testEmptySEADataThrowsError(testCase)
            emptyTable = table();
            testCase.verifyError(...
                @() extract_SEA_sectors(emptyTable, 'DEU'), ...
                'extract_SEA_sectors:InvalidSEA');
        end

        function testInvalidSEADataTypeThrowsError(testCase)
            testCase.verifyError(...
                @() extract_SEA_sectors('not a table', 'DEU'), ...
                'extract_SEA_sectors:InvalidSEA');
        end

        function testInvalidISOTypeThrowsError(testCase)
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, {'A01'});
            testCase.verifyError(...
                @() extract_SEA_sectors(SEA, 123), ...
                'extract_SEA_sectors:InvalidISO');
        end

        %% Output Structure Tests
        function testOutputCount(testCase)
            % extract_SEA_sectors is expected to return 22 output structs.
            % The multi-output assignment below errors if the function
            % provides fewer than 22 outputs, so the call itself enforces
            % the count. The loop then checks that every output is a struct.
            numOutputs = 22;
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, {'A01', 'A02', 'A03', 'B'});

            outputs = cell(1, numOutputs);
            [outputs{:}] = extract_SEA_sectors(SEA, 'DEU');

            for i = 1:numOutputs
                testCase.verifyClass(outputs{i}, 'struct', ...
                    sprintf('Output %d is not a struct', i));
            end
        end

        function testOutputVariableFields(testCase)
            % The synthetic SEA table carries VA, GO, II, LAB, COMP, EMP,
            % H_EMPE, EMPE and K. This test checks that A_tot exposes at
            % least the core subset listed in expectedFields.
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, {'A01', 'A02', 'A03'});

            [A_tot, ~] = extract_SEA_sectors(SEA, 'DEU');

            expectedFields = {'VA', 'GO', 'II', 'LAB', 'EMP', 'K'};
            for i = 1:numel(expectedFields)
                testCase.verifyTrue(isfield(A_tot, expectedFields{i}), ...
                    sprintf('Missing field: %s', expectedFields{i}));
            end
        end

        function testOutputDataShape(testCase)
            % Output data should be column vectors (years x 1)
            years = 2000:2016;
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, {'A01', 'A02', 'A03'}, years);

            [A_tot, ~] = extract_SEA_sectors(SEA, 'DEU');

            % VA should be a column vector with one element per year
            testCase.verifyEqual(size(A_tot.VA), [numel(years), 1]);
        end

        %% Sector Aggregation Tests
        function testAgricultureAggregation(testCase)
            % A_tot should aggregate A01, A02, A03
            sectors = {'A01', 'A02', 'A03', 'B'};
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, sectors);

            [A_tot, B_tot, ~] = extract_SEA_sectors(SEA, 'DEU');

            % Both should have data
            testCase.verifyFalse(all(A_tot.VA == 0), 'A_tot.VA should not be all zeros');
            testCase.verifyFalse(all(B_tot.VA == 0), 'B_tot.VA should not be all zeros');
        end

        function testManufacturingAggregation(testCase)
            % C_tot should aggregate all manufacturing sectors
            % Using a subset for testing
            sectors = {'C10-C12', 'C13-C15', 'C16'};
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, sectors);

            [~, ~, C_tot, ~] = extract_SEA_sectors(SEA, 'DEU');

            % C_tot should have aggregated data
            testCase.verifyFalse(all(C_tot.VA == 0), 'C_tot.VA should not be all zeros');
        end

        %% Country Extraction Tests
        function testSingleCountryExtraction(testCase)
            % With several countries present, selecting one must still work
            countries = {'DEU', 'FRA', 'ITA'};
            sectors = {'A01', 'A02', 'A03'};
            SEA = test_extract_SEA_sectors.createMinimalSEA(countries, sectors);

            % Extract only DEU
            [A_tot, ~] = extract_SEA_sectors(SEA, 'DEU');

            % Should get (non-empty) data for DEU
            testCase.verifyFalse(isempty(A_tot.VA));
        end

        function testCountryAsCellArray(testCase)
            % ISO can be passed as cell array
            sectors = {'A01', 'A02', 'A03'};
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, sectors);

            % This should work without error
            [A_tot, ~] = extract_SEA_sectors(SEA, {'DEU'});
            testCase.verifyFalse(isempty(A_tot.VA));
        end

        function testCountryAsString(testCase)
            % ISO can be passed as string
            sectors = {'A01', 'A02', 'A03'};
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, sectors);

            % This should work without error
            [A_tot, ~] = extract_SEA_sectors(SEA, "DEU");
            testCase.verifyFalse(isempty(A_tot.VA));
        end

        %% Edge Cases
        function testEmptyResultForMissingSectors(testCase)
            % If a sector doesn't exist in data, output should be zeros
            sectors = {'A01'};  % Only A01, not full agriculture
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, sectors);

            [A_tot, B_tot, ~] = extract_SEA_sectors(SEA, 'DEU');

            % A_tot should have data (from A01)
            testCase.verifyFalse(all(A_tot.VA == 0));
            % B_tot should be zeros (no B sector in data)
            testCase.verifyTrue(all(B_tot.VA == 0), 'B_tot should be zeros when B sector is missing');
        end

        function testYearColumnsPreserved(testCase)
            % Verify that year data is extracted correctly
            years = 2000:2016;
            sectors = {'A01'};
            SEA = test_extract_SEA_sectors.createMinimalSEA({'DEU'}, sectors, years);

            [A_tot, ~] = extract_SEA_sectors(SEA, 'DEU');

            % Should have one value per year
            testCase.verifyEqual(numel(A_tot.VA), numel(years));
        end
    end
end
