Создание функции, которая может принимать различные типы входных данных шаблонного класса

#c #templates #matrix

Вопрос:

У меня есть приложение, в котором мне нужно уменьшить использование памяти моими матричными классами. Многие из моих матриц симметричны, и поэтому должны требовать только N(N 1)/2 памяти вместо N 2. Я не могу просто иметь MatrixSym класс, производный от моего Matrix класса, потому что дочерний класс всегда использует больше памяти, чем базовый класс. Я определил MatrixAbstract класс, от которого наследуются оба класса, однако у меня возникли проблемы с получением функции умножения матриц для работы с ними обоими.

Мой вопрос

Как я могу определить C = multiply(A,B) функцию, которая принимает Matrix MatrixSym объект или для любого ввода/вывода? Мой минимальный рабочий пример не так уж минимален, потому что он содержит четыре определения multiply , которые все используют один и тот же код и должны быть объединены в одно.

Требования

  • Только статическое распределение
  • Необходимо использовать std::array для элементов матрицы
  • Минимальный повторяющийся код: У меня много матричных операций, и единственное, что отличается MatrixSym , — это способ хранения и доступа к элементам из базового matVals массива.
 #include <array> 
#include <iostream> 

template<unsigned int ROWS, unsigned int COLS, unsigned int NUMEL>
class MatrixAbstract
{
private:
    static constexpr unsigned int numel = NUMEL; 
    std::array<double, NUMEL> matVals;   // The array of matrix elements
public:
    MatrixSuperclass(){}
    virtual unsigned int index_from_rc(const unsigned intamp; row, const unsigned intamp; col) const = 0;
    // get the value at a given row and column
    double get_value(const intamp; row, const intamp; col) const
    {
        return this->matVals[this->index_from_rc(row,col)];
    }
    // set the value, given the row and column
    void set_value(const intamp; row, const intamp; col, double value)
    {
        this->matVals[this->index_from_rc(row,col)] = value;
    }

};

template<unsigned int ROWS, unsigned int COLS, unsigned int NUMEL = ROWS*COLS>
class Matrix : public MatrixSuperclass<ROWS, COLS, NUMEL>
{
public:
    Matrix(){}
    // get the linear index in matVals corresponding to a row,column input
    unsigned int index_from_rc(const unsigned intamp; row, const unsigned intamp; col) const {
        return row*COLS   col; 
    }
};

template<unsigned int ROWS, unsigned int COLS = ROWS, unsigned int NUMEL = ROWS*(ROWS 1)/2> 
class MatrixSym : public MatrixSuperclass<ROWS, COLS, NUMEL>
{
public:
    MatrixSym(){}
    // get the linear index in matVals corresponding to a row,column input (Symmetric matrix)
    unsigned int index_from_rc(const unsigned intamp; row, const unsigned intamp; col) const {
        unsigned int z;
        return ( ( z = ( row < col ? col : row ) ) * ( z   1 ) >> 1 )   ( col < row ? col : row ) ;
    }
};

// THE FOLLOWING FOUR FUNCTIONS ALL USE THE EXACT SAME CODE, ONLY INPUT/OUTPUT TYPES CHANGE

// Multiply a Matrix and Matrix and output a Matrix
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER>
Matrix<ROWS,COLS> multiply (Matrix<ROWS,INNER>amp; inMatrix1, Matrix<INNER,COLS>amp; inMatrix2) {
    Matrix<ROWS,COLS> outMatrix;
    for (unsigned int r = 0; r < ROWS; r  ) {
        for (unsigned int c = 0; c < COLS; c  ) {
            double val = 0.0;
            for (unsigned int rc = 0; rc < INNER; rc  ) {
                val  = inMatrix1.get_value(r,rc)*inMatrix2.get_value(rc,c);
            }
            outMatrix.set_value(r,c,val);
        }
    }
    return outMatrix;
}

// Multiply a Matrix and MatrixSym and output a Matrix
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER>
Matrix<ROWS,COLS> multiply (Matrix<ROWS,INNER>amp; inMatrix1, MatrixSym<INNER,COLS>amp; inMatrix2) {
    Matrix<ROWS,COLS> outMatrix;
    for (unsigned int r = 0; r < ROWS; r  ) {
        for (unsigned int c = 0; c < COLS; c  ) {
            double val = 0.0;
            for (unsigned int rc = 0; rc < INNER; rc  ) {
                val  = inMatrix1.get_value(r,rc)*inMatrix2.get_value(rc,c);
            }
            outMatrix.set_value(r,c,val);
        }
    }
    return outMatrix;
}

// Multiply a MatrixSym and Matrix and output a Matrix
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER>
Matrix<ROWS,COLS> multiply (MatrixSym<ROWS,INNER>amp; inMatrix1, Matrix<INNER,COLS>amp; inMatrix2) {
    //MatrixSym<ROWS,COLS> outMatrixSym;
    Matrix<ROWS,COLS> outMatrix;
    for (unsigned int r = 0; r < ROWS; r  ) {
        for (unsigned int c = 0; c < COLS; c  ) {
            double val = 0.0;
            for (unsigned int rc = 0; rc < INNER; rc  ) {
                val  = inMatrix1.get_value(r,rc)*inMatrix2.get_value(rc,c);
            }
            outMatrix.set_value(r,c,val);
        }
    }
    return outMatrix;
}

// Multiply a MatrixSym and MatrixSym and output a MatrixSym
template<unsigned int ROWS, unsigned int COLS, unsigned int INNER>
MatrixSym<ROWS,COLS> multiply (MatrixSym<ROWS,INNER>amp; inMatrix1, MatrixSym<INNER,COLS>amp; inMatrix2) {
    //MatrixSym<ROWS,COLS> outMatrixSym;
    MatrixSym<ROWS,COLS> outMatrix;
    for (unsigned int r = 0; r < ROWS; r  ) {
        for (unsigned int c = 0; c < COLS; c  ) {
            double val = 0.0;
            for (unsigned int rc = 0; rc < INNER; rc  ) {
                val  = inMatrix1.get_value(r,rc)*inMatrix2.get_value(rc,c);
            }
            outMatrix.set_value(r,c,val);
        }
    }
    return outMatrix;
}

int main()
{
    Matrix<3,3> A;
    MatrixSym<3> S;
    Matrix<3,3> AtimesA = multiply(A,A);
    Matrix<3,3> AtimesS = multiply(A,S);
    Matrix<3,3> StimesA = multiply(S,A);
    MatrixSym<3> StimesS = multiply(S,S);
    
    // Make sure that symmetric matrix S is indeed smaller than A
    std::cout << "sizeof(A)/sizeof(double) = " << sizeof(A)/sizeof(double) << std::endl;
    std::cout << "sizeof(S)/sizeof(double) = " << sizeof(S)/sizeof(double) << std::endl;
    return 0;
}
 

Выходы:

 sizeof(A)/sizeof(double) = 9
sizeof(S)/sizeof(double) = 15
 

Что я пробовал

  • Если я попытаюсь использовать функцию MatrixAbstract в качестве аргументов, мне нужно будет поиграть с параметрами шаблона, так как я не могу NUMEL использовать ее в качестве параметра шаблона. Кроме того, я не могу создать экземпляр a MatrixAbstract для возвращаемого значения.
  • Я не могу понять, как создать шаблон функции для или Matrix или MatrixSym . Моя догадка заключается в том, что это способ решить эту проблему, но я не понимаю, как создать шаблон функции для ввода шаблонного класса, но при этом иметь возможность использовать ROWS и COLS в качестве аргументов шаблона.
 template<class InMatrix1Type, class InMatrix2Type, class OutMatrixType, unsigned int ROWS, unsigned int COLS, unsigned int INNER>
OutMatrixType<ROWS,COLS> multiply (InMatrix1Type<ROWS,INNER>amp; inMatrix1, InMatrix2Type<INNER,COLS>amp; inMatrix2)
 

выдает мне ошибку компилятора, начинающуюся с

 error: ‘OutMatrixType’ is not a template
 OutMatrixType<ROWS,COLS> multiply (InMatrix1Type<ROWS,INNER>amp; inMatrix1, InMatrix2Type<INNER,COLS>amp; inMatrix2)
 ^~~~~~~~~~~~~
 

Комментарии:

1. Вы используете частичную специализацию по шаблонам — en.wikipedia.org/wiki/Partial_template_specialization ?

2. Наследование здесь-неправильный инструмент. Напишите свою multiply функцию в виде шаблона, который принимает два аргумента типа, по одному для каждой матрицы. Напишите «обычный» Matrix класс, который обеспечивает (если я правильно читаю ваш код с быстрым беглым просмотром) get_value(unsigned, unsigned) , который возвращает ссылку на значение этих индексов. Напишите второй SymmetricMatrix класс, который предоставляет get_value(unsigned, unsigned) , который возвращает ссылку на значение этих индексов. multiply не имеет значения, исходит ли полученное значение из полной матрицы или из симметричной (меньшей) матрицы.

3. Если вы не делаете это в учебных целях, рассмотрите возможность использования существующей библиотеки матриц.

4. @PeteBecker Matrix и SymmetricMatrix почти идентичны, за исключением способа хранения данных. Разве наследование не лучший метод для подобных ситуаций? Существует множество других функций-членов и атрибутов, которые будут иметь каждая матрица в моем коде, и у меня также есть третий специальный тип матрицы, который использует меньше памяти (в отличие от SymmetricMatrix . Я не хочу копировать и вставлять код для каждого типа матрицы.

5. Я использую старую библиотеку матриц Newmat11. Он действительно использует меньше памяти для SymmetricMatrix, чем для Matrix, но использует динамическое выделение памяти.

Ответ №1:

Функция multiply может быть шаблоном, который принимает два (возможно, разных) типа аргументов, при условии, что они поддерживают соответствующий интерфейс. Подобный этому:

 template <unsigned ROWS, unsigned COLS>
struct Matrix {
    double get_value(int row, int col) const;
    void set_value(int row, int col, double value);
};

template <unsigned ROWS, unsigned COLS>
struct Symmetric_Matrix {
    double get_value(int row, int col) const;
    void set_value(int row, int col, double value);
};

template <unsigned ROWS, unsigned COLS, unsigned INNER,
    template <unsigned, unsigned> class M1,
    template <unsigned, unsigned> class M2>
Matrix<ROWS, COLS> multiply(const M1<ROWS, INNER>amp; m1,
    const M2<INNER, COLS>amp; m2) {
    Matrix<ROWS, COLS> resu<
    for (unsigned r = 0; r < ROWS;   r)
        for (unsigned c = 0; c < COLS;   c) {
            double val = 0.0;
            for (unsigned rc = 0; rc < INNER;   rc)
                val  = m1.get_value(r, rc) * m2.get_value(rc, c);
        result.set_value(r, c, val);
        }
    return resu<
}
 

M1 и M2 являются параметрами шаблона шаблона; то есть это шаблоны, которые используются в качестве аргументов шаблона multiply ; компилятор определит их типы и соответствующие значения ROW COL , и INNER когда вызывается функция.

 int main() {
    Matrix<3, 5> res;

    Matrix<3, 4> x1;
    Matrix<4, 5> x2;
    res = multiply(x1, x2);

    Symmetric_Matrix<3, 4> sx1;
    Symmetric_Matrix<4, 5> sx2;
    res = multiply(sx1, sx2);

    res = multiply(x1, sx2);
    res = multiply(sx1, x2);

    return 0;
}
 

Конечно, реальный код предоставил бы реализации для get_value и set_value , но коду умножения все равно, как они реализованы, поэтому наличие двух разных типов будет работать просто отлично.

Ответ №2:

Поэтому я устранил много шума в вашем первоначальном вопросе, удалив конкретный код, относящийся к матрицам, и попытался свести его к тому, что было задано.

То, что я сделал здесь, заключалось в том, что я придерживался базового класса и вместо этого позволил пользователю указать тип возвращаемого значения, который они хотят для multiply функции.

Я не сильно проверял это, но, похоже, оно делает то, что вам нужно.

Чтобы уменьшить количество параметров шаблона, не относящихся к типу , в multiply функции, я включил несколько геттеров для row , col , и numel . Вы, конечно, можете отказаться от них и вернуться к тому, как у вас было раньше, но эти функции-члены позволят вам assert убедиться, что переданные параметры верны.

При всем при этом, как упоминал @Пит Беккер, вы также могли бы сделать это без наследования здесь. Прочтите его комментарий для получения дополнительной информации.

Это не полный пример, но он может помочь вам в окончательном решении.

 class MatrixBase {
public: 
    virtual double get_value( int row, int column ) const = 0;
    virtual void set_value( int row, int column, double value ) const = 0;
    virtual std::uint32_t get_rows( ) const = 0;
    virtual std::uint32_t get_cols( ) const = 0;
    virtual std::uint32_t get_numel( ) const = 0;
    virtual ~MatrixBase( ) = defau<
};

template<std::uint32_t Rows, std::uint32_t Cols, std::uint32_t Numel>
class Matrix : public MatrixBase {
public:
    double get_value( int row, int column ) const override { 
        return 1.0; 
    }

    void set_value( int row, int column, double value ) const override { }

    std::uint32_t get_rows( ) const override {
        return Rows;
    };

    std::uint32_t get_cols( ) const override {
        return Cols;
    }

    std::uint32_t get_numel( ) const override {
        return Numel;
    }

private:
    std::array<double, Numel> values_{ };
};

template<std::uint32_t Rows, std::uint32_t Cols, std::uint32_t Numel>
class MatrixSum : public MatrixBase {
public:
    double get_value( int row, int column ) const override {
        return 1.0;
    }

    void set_value( int row, int column, double value ) const override { }

    std::uint32_t get_rows( ) const override {
        return Rows;
    };

    std::uint32_t get_cols( ) const override {
        return Cols;
    }

    std::uint32_t get_numel( ) const override {
        return Numel;
    }

private:
    std::array<double, Numel> values_{ };
};


template<typename T, std::uint32_t Inner>
static T multiply( const MatrixBaseamp; m1, const MatrixBaseamp; m2 ) {
    static_assert( std::is_base_of_v<MatrixBase, T>, 
        "Return type must derive from MatrixBase" );

    static_assert( std::is_default_constructible_v<T>, 
        "Type must be default constructable" );

    T out{ };

    // Get the values.
    const auto m1_values{ m1.get_value( 1, 0 ) };
    const auto m2_values{ m2.get_value( 1, 0 ) };

    // Set the values on the new matrix.
    out.set_value( 1, 0, m1_values * m2_values );

    return out;
}

int main( ) {
    MatrixSum<10, 10, 10> matrix_sum{ };
    Matrix<10, 10, 10> matrix{ };

    auto m{ multiply<Matrix<10, 10, 10>, 10>( matrix_sum, matrix ) };
    auto ms{ multiply<MatrixSum<10, 10, 10>, 10>( matrix, matrix_sum ) };
}
 

Или, если вы используете C 20 , вы можете просто определить a concept с requires помощью предложения, в котором вы можете указать интерфейс, который должен иметь переданный тип. Итак, используя приведенные выше определения, которые могут выглядеть следующим образом.

 template<typename T>
concept Mat = std::is_default_constructible_v<T> amp;amp; 
requires( T m, int row, int col, double val ) {
    { m.get_rows( ) } -> std::same_as<std::uint32_t>;
    { m.get_cols( ) } -> std::same_as<std::uint32_t>;
    { m.get_numel( ) } -> std::same_as<std::uint32_t>;
    { m.get_value( row, col ) } -> std::same_as<double>;
    m.set_value( row, col, val );
};

template<std::uint32_t Inner, Mat T1, Mat T2, Mat T3>
static T1 multiply( const T2amp; m1, const T3amp; m2 ) {
    T1 out{ };

    // Get the values.
    const auto m1_values{ m1.get_value( 1, 0 ) };
    const auto m2_values{ m2.get_value( 1, 0 ) };

    // Set the values on the new matrix.
    out.set_value( 1, 0, m1_values * m2_values );

    return out;
}
 

При таком подходе вы можете полностью удалить MatrixBase class , если хотите, и просто ограничить multiply функцию типами, которые предоставляют желаемую функциональность.

Комментарии:

1. Спасибо @WBuck, здесь много хороших идей, которые я могу использовать. Это не совсем окончательное решение, которое я ищу (я не хочу указывать аргументы шаблона каждый раз при вызове multiply и другие матричные функции, так как это потребовало бы от меня изменения большого количества существующего кода), но это определенно приближает меня.

Ответ №3:

@PeteBecker ответ является лучшим решением, чтобы то, что я просил, но я буду размещать то, что я думаю, что я собираюсь использовать как еще одно решение здесь, поскольку оно использует тот факт, что произведение двух симметричных матриц может использовать меньше операций, чем стандартный матрица умножения, и я все еще хочу использовать базовый класс, поэтому мне не нужно дублировать код Matrix и MatrixSym (у них больше функций-членов, чем просто get_value а set_value )

 #include <array> 
#include <iostream> 

template<std::uint8_t ROWS, std::uint8_t COLS, std::uint16_t STORAGE_SIZE>
class MatrixBase {
private:
    virtual std::uint16_t index_from_rc(const std::uint8_tamp; row, const std::uint8_tamp; col) const = 0;
    std::array<double, STORAGE_SIZE> values_{ };
public: 
    // get the value at a given row and column
    double get_value(const std::uint8_tamp; row, const std::uint8_tamp; col) const
    {
        return this->values_[this->index_from_rc(row,col)];
    }
    // set the value, given the row and column
    void set_value(const std::uint8_tamp; row, const std::uint8_tamp; col, doubleamp; value)
    {
        this->values_[this->index_from_rc(row,col)] = value;
    }
};

template<std::uint8_t ROWS, std::uint8_t COLS>
class Matrix : public MatrixBase<ROWS, COLS, ROWS*COLS> {
private:
    std::uint16_t index_from_rc(const std::uint8_tamp; row, const std::uint8_tamp; col) const {
        return row*COLS   col; 
    }
};

// Symmetric matrix "MatrixSym": must be square as well, so only need to specify ROWS, since ROWS = COLS 
template<std::uint8_t ROWS>
class MatrixSym : public MatrixBase<ROWS, ROWS, ROWS*(ROWS 1)/2> {
private:
    std::uint16_t index_from_rc(const std::uint8_tamp; row, const std::uint8_tamp; col) const {
        std::uint8_t z;
        return ( ( z = ( row < col ? col : row ) ) * ( z   1 ) >> 1 )   ( col < row ? col : row ) ;
    }
};

// General function to multiply a MatrixBase and MatrixBase and output a Matrix
template<std::uint8_t ROWS, std::uint8_t COLS, std::uint8_t INNER, std::uint16_t STORAGE_SIZE1, std::uint16_t STORAGE_SIZE2>
Matrix<ROWS,COLS> operator*( const MatrixBase<ROWS,INNER,STORAGE_SIZE1>amp; inMatrix1, const MatrixBase<INNER,COLS,STORAGE_SIZE2>amp; inMatrix2){
    std::cout << "Multiplying two MatrixBase objects and outputting a Matrix" << std::endl;
    Matrix<ROWS,COLS> outMatrix;
    for (std::uint8_t r = 0; r < ROWS; r  ) {
        for (std::uint8_t c = 0; c < COLS; c  ) {
            double val = 0.0;
            for (std::uint8_t rc = 0; rc < INNER; rc  ) {
                val  = inMatrix1.get_value(r,rc)*inMatrix2.get_value(rc,c);
            }
            outMatrix.set_value(r,c,val);
        }
    }
    return outMatrix;
}

// Special case function: multiply a MatrixSym and MatrixSym and output a MatrixSym
template<std::uint8_t ROWS>
MatrixSym<ROWS> operator*( const MatrixSym<ROWS>amp; inMatrix1, const MatrixSym<ROWS>amp; inMatrix2){
    std::cout << "Multiplying two MatrixSym objects and outputting a MatrixSym" << std::endl;
    MatrixSym<ROWS> outMatrix;
    for (std::uint8_t r = 0; r < ROWS; r  ) {
        for (std::uint8_t c = r; c < ROWS; c  ) { // improve efficiency by starting c at r (instead of 0) for symmetric matrices
            double val = 0.0;
            for (std::uint8_t rc = 0; rc < ROWS; rc  ) {
                val  = inMatrix1.get_value(r,rc)*inMatrix2.get_value(rc,c);
            }
            outMatrix.set_value(r,c,val);
        }
    }
    return outMatrix;
}

int main( ) {
    MatrixSym<3> S;
    Matrix<3,3> A;

    std::cout << "sizeof(A) = " << sizeof(A) << std::endl;
    std::cout << "sizeof(S) = " << sizeof(S) << std::endl;

    std::cout << "Calculating A*A" << std::endl;
    auto AtimesA = A*A; // auto = Matrix
    std::cout << "Calculating S*A" << std::endl;
    auto StimesA = S*A; // auto = Matrix
    std::cout << "Calculating A*S" << std::endl;
    auto AtimesS = A*S; // auto = Matrix
    std::cout << "Calculating S*S" << std::endl;
    auto StimesS = S*S; // auto = MatrixSym
    std::cout << "Calculating S*S*A" << std::endl;
    auto StimesStimesA = S*S*A; // auto = MatrixSym
}
 

какие результаты:

 sizeof(A) = 80
sizeof(S) = 56
Calculating A*A
Multiplying two MatrixBase objects and outputting a Matrix
Calculating S*A
Multiplying two MatrixBase objects and outputting a Matrix
Calculating A*S
Multiplying two MatrixBase objects and outputting a Matrix
Calculating S*S
Multiplying two MatrixSym objects and outputting a MatrixSym
Calculating S*S*A
Multiplying two MatrixSym objects and outputting a MatrixSym
Multiplying two MatrixBase objects and outputting a Matrix