1 module dvec.matrix;
2 
3 import std.traits : isNumeric, isFloatingPoint;
4 import dvec.vector;
5 
/**
 * A dense `rowCount` x `colCount` matrix of numeric elements of type `T`,
 * stored row-major in a fixed-size array.
 *
 * A default-constructed matrix holds `T.init` in every element (NaN for
 * floating-point `T`). Square matrices (`rowCount == colCount`) additionally
 * provide `identity`, `det`, `invertible`, `cofactor`, `adjugate` and `inv`.
 */
struct Mat(T, size_t rowCount, size_t colCount) if (isNumeric!T && rowCount > 0 && colCount > 0) {
    /// Row-major element storage: element (i, j) lives at `data[colCount * i + j]`.
    public T[rowCount * colCount] data;

    /**
     * Constructs a matrix from a row-major static array of all elements.
     * Params:
     *   elements = The elements, row by row.
     */
    public this(T[rowCount * colCount] elements) {
        data[0 .. $] = elements[0 .. $];
    }

    /**
     * Constructs a matrix from a row-major list of elements, e.g. `Mat2d(1, 2, 3, 4)`.
     * Params:
     *   elements = The elements, row by row; there must be exactly
     *              `rowCount * colCount` of them.
     */
    public this(T[] elements...) {
        data[0 .. $] = elements[0 .. $];
    }

    /**
     * Constructs a matrix as an element-wise copy of another of the same shape.
     * Params:
     *   other = The matrix to copy.
     */
    public this(Mat!(T, rowCount, colCount) other) {
        static foreach (i; 0 .. data.length) data[i] = other.data[i];
    }

    /**
     * Constructs a matrix with every element set to `value`.
     * Params:
     *   value = The value to fill the matrix with.
     */
    public this(T value) {
        static foreach (i; 0 .. data.length) data[i] = value;
    }

    /// Maps a (row, column) pair to its offset in `data`, asserting it is in range.
    private size_t convertToIndex(size_t i, size_t j) {
        // Without this check, e.g. (0, colCount) would silently alias element (1, 0).
        assert(i < rowCount && j < colCount, "Matrix index out of bounds.");
        return colCount * i + j;
    }

    /// Returns the element at row `i`, column `j`.
    public T opIndex(size_t i, size_t j) {
        return data[convertToIndex(i, j)];
    }

    /// Sets the element at row `i`, column `j` to `value`.
    public void opIndexAssign(T value, size_t i, size_t j) {
        data[convertToIndex(i, j)] = value;
    }

    /**
     * Gets a row of this matrix as a vector.
     * Params:
     *   row = The row index.
     * Returns: A vector constructed from the row's elements.
     */
    public Vec!(T, colCount) getRow(size_t row) {
        size_t idx = convertToIndex(row, 0);
        return Vec!(T, colCount)(data[idx .. idx + colCount]);
    }

    /**
     * Overwrites a row of this matrix with the components of a vector.
     * Params:
     *   row = The row index.
     *   vector = The new row contents.
     */
    public void setRow(size_t row, Vec!(T, colCount) vector) {
        size_t idx = convertToIndex(row, 0);
        data[idx .. idx + colCount] = vector.data;
    }

    /**
     * Gets a column of this matrix as a vector.
     * Params:
     *   col = The column index.
     * Returns: A vector holding the column's elements, top to bottom.
     */
    public Vec!(T, rowCount) getCol(size_t col) {
        Vec!(T, rowCount) v;
        static foreach (i; 0 .. rowCount) {
            v[i] = this[i, col];
        }
        return v;
    }

    /**
     * Overwrites a column of this matrix with the components of a vector.
     * Params:
     *   col = The column index.
     *   vector = The new column contents, top to bottom.
     */
    public void setCol(size_t col, Vec!(T, rowCount) vector) {
        static foreach (i; 0 .. rowCount) {
            this[i, col] = vector[i];
        }
    }

    /// Adds `other` to this matrix element-wise, in place.
    public void add(Mat!(T, rowCount, colCount) other) {
        static foreach (i; 0 .. data.length) data[i] += other.data[i];
    }

    /// Subtracts `other` from this matrix element-wise, in place.
    public void sub(Mat!(T, rowCount, colCount) other) {
        static foreach (i; 0 .. data.length) data[i] -= other.data[i];
    }

    /// Multiplies every element by `factor`, in place.
    public void mul(T factor) {
        static foreach (i; 0 .. data.length) data[i] *= factor;
    }

    /// Divides every element by `factor`, in place (integer division for integral `T`).
    public void div(T factor) {
        static foreach (i; 0 .. data.length) data[i] /= factor;
    }

    /**
     * Computes the transpose of this matrix.
     * Returns: A new `colCount` x `rowCount` matrix `m` with `m[j, i] == this[i, j]`.
     */
    public Mat!(T, colCount, rowCount) transpose() {
        Mat!(T, colCount, rowCount) m;
        static foreach (i; 0 .. rowCount) {
            static foreach (j; 0 .. colCount) {
                m[j, i] = this[i, j];
            }
        }
        return m;
    }

    /** 
     * Computes the matrix multiplication of `this * other`.
     *
     * Only defined when the inner dimensions agree, i.e. `other` has exactly
     * `colCount` rows; mismatched shapes are rejected at compile time. The
     * element type of the result is that of `other`.
     * Params:
     *   other = The matrix to multiply with this one.
     * Returns: The resultant `rowCount` x `otherColCount` matrix.
     */
    public Mat!(U, rowCount, otherColCount) mul(U, size_t otherRowCount, size_t otherColCount)
        (Mat!(U, otherRowCount, otherColCount) other) if (otherRowCount == colCount) {
        Mat!(U, rowCount, otherColCount) m;
        U sum;
        static foreach (i; 0 .. rowCount) {
            static foreach (j; 0 .. otherColCount) {
                sum = 0;
                static foreach (k; 0 .. colCount) {
                    sum += this[i, k] * other[k, j];
                }
                m[i, j] = sum;
            }
        }
        return m;
    }

    /** 
     * Multiplies a vector against this matrix (`this * vector`).
     * Params:
     *   vector = The vector to multiply, treated as a column vector.
     * Returns: The resultant transformed vector.
     */
    public Vec!(T, rowCount) mul(Vec!(T, colCount) vector) {
        Vec!(T, rowCount) result;
        T sum;
        static foreach (i; 0 .. rowCount) {
            sum = 0;
            static foreach (j; 0 .. colCount) {
                sum += this[i, j] * vector[j];
            }
            result[i] = sum;
        }
        return result;
    }

    /**
     * Elementary row operation: swaps rows `rowI` and `rowJ` in place.
     * Params:
     *   rowI = The first row index.
     *   rowJ = The second row index.
     */
    public void rowSwitch(size_t rowI, size_t rowJ) {
        if (rowI == rowJ) return;
        immutable size_t a = convertToIndex(rowI, 0);
        immutable size_t b = convertToIndex(rowJ, 0);
        foreach (k; 0 .. colCount) {
            T tmp = data[a + k];
            data[a + k] = data[b + k];
            data[b + k] = tmp;
        }
    }

    /**
     * Elementary row operation: multiplies every element of `row` by `factor`, in place.
     * Params:
     *   row = The row index.
     *   factor = The scale factor.
     */
    public void rowMultiply(size_t row, T factor) {
        size_t idx = convertToIndex(row, 0);
        static foreach (i; 0 .. colCount) {
            data[idx + i] *= factor;
        }
    }

    /**
     * Elementary row operation: adds `factor` times row `rowJ` to row `rowI`
     * (`R_i <- R_i + factor * R_j`), in place. Row `rowJ` is left unchanged
     * unless it is the same row as `rowI`.
     * Params:
     *   rowI = The row being modified.
     *   factor = The multiple of `rowJ` to add.
     *   rowJ = The row being added.
     */
    public void rowAdd(size_t rowI, T factor, size_t rowJ) {
        immutable size_t dst = convertToIndex(rowI, 0);
        immutable size_t src = convertToIndex(rowJ, 0);
        foreach (k; 0 .. colCount) {
            data[dst + k] += factor * data[src + k];
        }
    }

    /**
     * Builds the matrix obtained by deleting the given rows and columns.
     *
     * The indices in `rows` and `cols` must be distinct and in range; this is
     * asserted, since otherwise the result would be over- or under-filled.
     * Params:
     *   rows = The row indices to remove.
     *   cols = The column indices to remove.
     * Returns: The remaining `(rowCount - n)` x `(colCount - m)` matrix, with
     *          the surviving elements kept in their original relative order.
     */
    public Mat!(T, rowCount - n, colCount - m) subMatrix(size_t n, size_t m)(size_t[n] rows, size_t[m] cols)
        if (n < rowCount && m < colCount) {
        // Linear membership test over a (short) index list.
        static bool listed(const(size_t)[] list, size_t value) {
            foreach (x; list) {
                if (x == value) return true;
            }
            return false;
        }

        Mat!(T, rowCount - n, colCount - m) sub;
        size_t subIdx = 0;
        foreach (row; 0 .. rowCount) {
            if (listed(rows[], row)) continue;
            foreach (col; 0 .. colCount) {
                if (listed(cols[], col)) continue;
                // Fewer distinct in-range indices than n/m would leave too many survivors.
                assert(subIdx < sub.data.length,
                    "subMatrix: rows and cols must be distinct, in-range indices.");
                sub.data[subIdx++] = this[row, col];
            }
        }
        return sub;
    }

    static if (rowCount == colCount) {
        alias N = rowCount;

        /// Returns the `N` x `N` identity matrix.
        public static Mat!(T, N, N) identity() {
            Mat!(T, N, N) m;
            static foreach (i; 0 .. N) {
                static foreach (j; 0 .. N) {
                    m[i, j] = i == j ? 1 : 0;
                }
            }
            return m;
        }

        /**
         * Computes the determinant by Laplace (cofactor) expansion along the
         * first row. The cost grows factorially with `N`.
         * Returns: The determinant of this matrix.
         */
        public T det() {
            static if (N == 1) {
                return data[0];
            } else static if (N == 2) {
                return data[0] * data[3] - data[1] * data[2];
            } else {
                // Laplace expansion, taking i = 0.
                T sum = 0;
                static foreach (j; 0 .. N) {
                    sum += (j % 2 == 0 ? 1 : -1) * this[0, j] * this.subMatrix([0], [j]).det();
                }
                return sum;
            }
        }

        /**
         * Tells whether this matrix is invertible, i.e. its determinant is not
         * exactly zero. For floating-point `T` this is an exact comparison.
         */
        public bool invertible() {
            return det() != 0;
        }

        /**
         * Computes the cofactor matrix `C`, where `C[i, j]` is `(-1)^(i+j)`
         * times the determinant of the minor with row `i` and column `j` removed.
         */
        public Mat!(T, N, N) cofactor() {
            static if (N == 1) {
                // The only minor of a 1x1 matrix is the empty matrix, whose
                // determinant is 1 by convention, so the cofactor matrix is [1].
                return identity();
            } else {
                Mat!(T, N, N) c;
                static foreach (i; 0 .. N) {
                    static foreach (j; 0 .. N) {
                        c[i, j] = ((i + j) % 2 == 0 ? 1 : -1) * this.subMatrix([i], [j]).det();
                    }
                }
                return c;
            }
        }

        /// Computes the adjugate, i.e. the transpose of the cofactor matrix.
        public Mat!(T, N, N) adjugate() {
            return cofactor().transpose();
        }

        /**
         * Computes the inverse as `adjugate() / det()`.
         *
         * No singularity check is performed: the determinant is passed straight
         * to `div`, so callers should check `invertible()` first. For integral
         * `T` the division is integer division.
         */
        public Mat!(T, N, N) inv() {
            auto m = adjugate();
            m.div(det());
            return m;
        }
    }
}
233 
// Shorthand names for the commonly used square matrix types, grouped by
// dimension: the `f` suffix selects `float` elements, `d` selects `double`.
alias Mat2f = Mat!(float, 2, 2);
alias Mat2d = Mat!(double, 2, 2);

alias Mat3f = Mat!(float, 3, 3);
alias Mat3d = Mat!(double, 3, 3);

alias Mat4f = Mat!(float, 4, 4);
alias Mat4d = Mat!(double, 4, 4);
242 
unittest {
    import std.stdio;
    import dvec.vector;

    // Default construction allocates the full backing array.
    {
        auto blank = Mat3d();
        assert(blank.data.length == 9);
    }

    // Transposing a non-square matrix.
    {
        auto wide = Mat!(double, 2, 3)([1, 2, 3, 0, -6, 7]);
        auto tall = wide.transpose();
        assert(tall.data == [1, 0, 2, -6, 3, 7]);
    }

    // Row/column extraction and the matrix-matrix product.
    {
        auto square = Mat2d([1, 2, 3, 4]);
        assert(square.getRow(0).data == [1, 2]);
        assert(square.getRow(1).data == [3, 4]);
        assert(square.getCol(0).data == [1, 3]);
        assert(square.getCol(1).data == [2, 4]);
        auto product = square.mul(Mat2d([0, 1, 0, 0]));
        assert(product.data == [0, 1, 0, 3]);
    }

    // Identity matrices.
    {
        auto eye2 = Mat2d.identity();
        auto eye3 = Mat3d.identity();
        assert(eye2.data == [1, 0, 0, 1]);
        assert(eye3.data == [1, 0, 0, 0, 1, 0, 0, 0, 1]);
    }

    // Matrix-vector product.
    {
        auto transform = Mat!(double, 2, 3)([1, -1, 2, 0, -3, 1]);
        Vec3d input = Vec3d(2, 1, 0);
        assert(transform.mul(input).data == [1, -3]);
    }

    // Removing a row and a column.
    {
        auto big = Mat!(double, 3, 4)([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
        auto minor = big.subMatrix([2], [1]);
        assert(minor.data == [1, 3, 4, 5, 7, 8]);
    }

    // Determinants.
    assert(Mat2d([3, 7, 1, -4]).det == -19);
    assert(Mat2d([1, 2, 3, 4]).det == -2);
    assert(Mat3d([1, 2, 3, 4, 5, 6, 7, 8, 9]).det == 0);

    // Inverses.
    assert(Mat2d.identity().inv() == Mat2d.identity());
    assert(Mat3f.identity().inv() == Mat3f.identity());
    assert(Mat2d(4, 7, 2, 6).inv() == Mat2d(0.6, -0.7, -0.2, 0.4));
    assert(Mat2d(-3, 1, 5, -2).inv() == Mat2d(-2, -1, -5, -3));
    assert(Mat3d(1, 3, 3, 1, 4, 3, 1, 3, 4).inv() == Mat3d(7, -3, -3, -1, 1, 0, -1, 0, 1));
}