Commit 4cff92bb by Maarten L. Hekkelman

symmetry operations now working correctly

parent 9aa8a223
...@@ -347,13 +347,44 @@ auto operator-(const matrix_expression<M1> &m1, const matrix_expression<M2> &m2) ...@@ -347,13 +347,44 @@ auto operator-(const matrix_expression<M1> &m1, const matrix_expression<M2> &m2)
return matrix_subtraction(m1, m2); return matrix_subtraction(m1, m2);
} }
template <typename M, typename F> template <typename M1, typename M2>
class matrix_multiplication : public matrix_expression<matrix_multiplication<M, F>> class matrix_matrix_multiplication : public matrix_expression<matrix_matrix_multiplication<M1, M2>>
{ {
public: public:
using value_type = F; matrix_matrix_multiplication(const M1 &m1, const M2 &m2)
: m_m1(m1)
, m_m2(m2)
{
assert(m1.dim_m() == m2.dim_n());
}
constexpr uint32_t dim_m() const { return m_m1.dim_m(); }
constexpr uint32_t dim_n() const { return m_m1.dim_n(); }
constexpr auto operator()(uint32_t i, uint32_t j) const
{
using value_type = decltype(m_m1(0, 0));
value_type result = {};
for (uint32_t k = 0; k < m_m1.dim_m(); ++k)
result += m_m1(i, k) * m_m2(k, j);
matrix_multiplication(const M &m, value_type v) return result;
}
private:
const M1 &m_m1;
const M2 &m_m2;
};
template<typename M, typename T>
class matrix_scalar_multiplication : public matrix_expression<matrix_scalar_multiplication<M, T>>
{
public:
using value_type = T;
matrix_scalar_multiplication(const M &m, value_type v)
: m_m(m) : m_m(m)
, m_v(v) , m_v(v)
{ {
...@@ -372,10 +403,17 @@ class matrix_multiplication : public matrix_expression<matrix_multiplication<M, ...@@ -372,10 +403,17 @@ class matrix_multiplication : public matrix_expression<matrix_multiplication<M,
value_type m_v; value_type m_v;
}; };
template <typename M, typename F>
auto operator*(const matrix_expression<M> &m, F v) template <typename M1, typename T, std::enable_if_t<std::is_floating_point_v<T>, int> = 0>
auto operator*(const matrix_expression<M1> &m, T v)
{
return matrix_scalar_multiplication(m, v);
}
template <typename M1, typename M2, std::enable_if_t<not std::is_floating_point_v<M2>, int> = 0>
auto operator*(const matrix_expression<M1> &m1, const matrix_expression<M2> &m2)
{ {
return matrix_multiplication(m, v); return matrix_matrix_multiplication(m1, m2);
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
......
...@@ -210,14 +210,14 @@ class row_handle ...@@ -210,14 +210,14 @@ class row_handle
return detail::get_row_result<C...>(*this, { get_column_ix(columns)... }); return detail::get_row_result<C...>(*this, { get_column_ix(columns)... });
} }
template <typename... Ts, typename... C, std::enable_if_t<sizeof...(Ts) == sizeof...(C), int> = 0> template <typename... Ts, typename... C, std::enable_if_t<sizeof...(Ts) == sizeof...(C) and sizeof...(C) != 1, int> = 0>
std::tuple<Ts...> get(C... columns) const std::tuple<Ts...> get(C... columns) const
{ {
return detail::get_row_result<Ts...>(*this, { get_column_ix(columns)... }); return detail::get_row_result<Ts...>(*this, { get_column_ix(columns)... });
} }
template <typename T> template <typename T>
T get(const char *column) T get(const char *column) const
{ {
return operator[](get_column_ix(column)).template as<T>(); return operator[](get_column_ix(column)).template as<T>();
} }
......
...@@ -39,6 +39,17 @@ namespace cif ...@@ -39,6 +39,17 @@ namespace cif
// -------------------------------------------------------------------- // --------------------------------------------------------------------
inline point operator*(const matrix3x3<float> &m, const point &pt)
{
return {
m(0, 0) * pt.m_x + m(0, 1) * pt.m_y + m(0, 2) * pt.m_z,
m(1, 0) * pt.m_x + m(1, 1) * pt.m_y + m(1, 2) * pt.m_z,
m(2, 0) * pt.m_x + m(2, 1) * pt.m_y + m(2, 2) * pt.m_z
};
}
// --------------------------------------------------------------------
enum class space_group_name enum class space_group_name
{ {
full, full,
...@@ -195,7 +206,7 @@ class cell ...@@ -195,7 +206,7 @@ class cell
}; };
/// @brief A class that encapsulates the symmetry operations as used in PDB files, i.e. a rotational number and a translation vector /// @brief A class that encapsulates the symmetry operations as used in PDB files, i.e. a rotational number and a translation vector
class sym_op struct sym_op
{ {
public: public:
sym_op(uint8_t nr = 1, uint8_t ta = 5, uint8_t tb = 5, uint8_t tc = 5) sym_op(uint8_t nr = 1, uint8_t ta = 5, uint8_t tb = 5, uint8_t tc = 5)
...@@ -225,13 +236,12 @@ class sym_op ...@@ -225,13 +236,12 @@ class sym_op
std::string string() const; std::string string() const;
friend class spacegroup;
private:
uint8_t m_nr; uint8_t m_nr;
uint8_t m_ta, m_tb, m_tc; uint8_t m_ta, m_tb, m_tc;
}; };
static_assert(sizeof(sym_op) == 4, "Sym_op should be four bytes");
namespace literals namespace literals
{ {
inline sym_op operator""_symop(const char *text, size_t length) inline sym_op operator""_symop(const char *text, size_t length)
...@@ -257,6 +267,16 @@ class transformation ...@@ -257,6 +267,16 @@ class transformation
point operator()(const cell &c, const point &pt) const; point operator()(const cell &c, const point &pt) const;
point operator()(const point &pt) const
{
return m_rotation * pt + m_translation;
}
friend transformation operator*(const transformation &lhs, const transformation &rhs);
friend transformation inverse(const transformation &t);
friend class spacegroup;
private: private:
matrix3x3<float> m_rotation; matrix3x3<float> m_rotation;
point m_translation; point m_translation;
...@@ -264,12 +284,19 @@ class transformation ...@@ -264,12 +284,19 @@ class transformation
// -------------------------------------------------------------------- // --------------------------------------------------------------------
int get_space_group_number(std::string_view spacegroup); // alternative for clipper's parsing code, using space_group_name::full int get_space_group_number(const datablock &db);
int get_space_group_number(std::string_view spacegroup, space_group_name type); // alternative for clipper's parsing code int get_space_group_number(std::string_view spacegroup);
int get_space_group_number(std::string_view spacegroup, space_group_name type);
class spacegroup : public std::vector<transformation> class spacegroup : public std::vector<transformation>
{ {
public: public:
spacegroup(const datablock &db)
: spacegroup(get_space_group_number(db))
{
}
spacegroup(std::string_view name) spacegroup(std::string_view name)
: spacegroup(get_space_group_number(name)) : spacegroup(get_space_group_number(name))
{ {
...@@ -285,7 +312,7 @@ class spacegroup : public std::vector<transformation> ...@@ -285,7 +312,7 @@ class spacegroup : public std::vector<transformation>
int get_nr() const { return m_nr; } int get_nr() const { return m_nr; }
std::string get_name() const; std::string get_name() const;
point operator()(const point &pt, const cell &c, sym_op symop); point operator()(const point &pt, const cell &c, sym_op symop) const;
private: private:
int m_nr; int m_nr;
...@@ -293,36 +320,6 @@ class spacegroup : public std::vector<transformation> ...@@ -293,36 +320,6 @@ class spacegroup : public std::vector<transformation>
}; };
// -------------------------------------------------------------------- // --------------------------------------------------------------------
int get_space_group_number(std::string spacegroup); // alternative for clipper's parsing code, using space_group_name::full
int get_space_group_number(std::string spacegroup, space_group_name type); // alternative for clipper's parsing code
// class rtop
// {
// public:
// rtop(const spacegroup &sg, const cell &c, int nr);
// friend rtop operator+(rtop rt, cif::point t);
// friend cif::point operator*(cif::point p, rtop rt);
// private:
// cell m_c;
// cif::quaternion m_q;
// cif::point m_t;
// };
// class spacegroup
// {
// public:
// spacegroup(const cif::datablock &db);
// private:
// std::vector<
// };
static_assert(sizeof(sym_op) == 4, "Sym_op should be four bytes");
// --------------------------------------------------------------------
// Symmetry operations on points // Symmetry operations on points
template <typename T> template <typename T>
...@@ -343,4 +340,24 @@ inline point fractional(const point &pt, const cell &c) ...@@ -343,4 +340,24 @@ inline point fractional(const point &pt, const cell &c)
return c.get_fractional_matrix() * pt; return c.get_fractional_matrix() * pt;
} }
inline transformation orthogonal(const transformation &t, const cell &c)
{
return transformation(c.get_orthogonal_matrix(), {}) * t * transformation(c.get_fractional_matrix(), {});
}
inline transformation fractional(const transformation &t, const cell &c)
{
return transformation(c.get_fractional_matrix(), {}) * t * transformation(c.get_orthogonal_matrix(), {});
}
// --------------------------------------------------------------------
inline point symmetry_copy(const point &pt, const spacegroup &sg, const cell &c, sym_op symop)
{
return sg(pt, c, symop);
}
std::tuple<float,point,sym_op> closest_symmetry_copy(const spacegroup &sg, const cell &c, point a, point b);
} // namespace cif } // namespace cif
...@@ -52,7 +52,7 @@ cell::cell(const datablock &db) ...@@ -52,7 +52,7 @@ cell::cell(const datablock &db)
{ {
auto &_cell = db["cell"]; auto &_cell = db["cell"];
cif::tie(m_a, m_b, m_c, m_alpha, m_beta, m_gamma) = tie(m_a, m_b, m_c, m_alpha, m_beta, m_gamma) =
_cell.front().get("length_a", "length_b", "length_c", "angle_alpha", "angle_beta", "angle_gamma"); _cell.front().get("length_a", "length_b", "length_c", "angle_alpha", "angle_beta", "angle_gamma");
init(); init();
...@@ -98,6 +98,22 @@ sym_op::sym_op(std::string_view s) ...@@ -98,6 +98,22 @@ sym_op::sym_op(std::string_view s)
throw std::invalid_argument("Could not convert string into sym_op"); throw std::invalid_argument("Could not convert string into sym_op");
} }
std::string sym_op::string() const
{
char b[9];
auto r = std::to_chars(b, b + sizeof(b), m_nr);
if (r.ec != std::errc() or r.ptr > b + 4)
throw std::runtime_error("Could not write out symmetry operation to string");
*r.ptr++ = '_';
*r.ptr++ = '0' + m_ta;
*r.ptr++ = '0' + m_tb;
*r.ptr++ = '0' + m_tc;
*r.ptr = 0;
return { b, r.ptr - b };
}
// -------------------------------------------------------------------- // --------------------------------------------------------------------
transformation::transformation(const symop_data &data) transformation::transformation(const symop_data &data)
...@@ -116,7 +132,24 @@ transformation::transformation(const symop_data &data) ...@@ -116,7 +132,24 @@ transformation::transformation(const symop_data &data)
m_translation.m_x = d[9] == 0 ? 0 : 1.0 * d[9] / d[10]; m_translation.m_x = d[9] == 0 ? 0 : 1.0 * d[9] / d[10];
m_translation.m_y = d[11] == 0 ? 0 : 1.0 * d[11] / d[12]; m_translation.m_y = d[11] == 0 ? 0 : 1.0 * d[11] / d[12];
m_translation.m_y = d[13] == 0 ? 0 : 1.0 * d[13] / d[14]; m_translation.m_z = d[13] == 0 ? 0 : 1.0 * d[13] / d[14];
}
transformation operator*(const transformation &lhs, const transformation &rhs)
{
auto r = lhs.m_rotation * rhs.m_rotation;
auto t = lhs.m_rotation * rhs.m_translation;
t = t + lhs.m_translation;
return transformation(r, t);
// return transformation(lhs.m_rotation * rhs.m_rotation, lhs.m_rotation * rhs.m_translation + lhs.m_translation);
}
transformation inverse(const transformation &t)
{
auto inv_matrix = inverse(t.m_rotation);
return { inv_matrix, -(inv_matrix * t.m_translation) };
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
...@@ -124,12 +157,12 @@ transformation::transformation(const symop_data &data) ...@@ -124,12 +157,12 @@ transformation::transformation(const symop_data &data)
spacegroup::spacegroup(int nr) spacegroup::spacegroup(int nr)
: m_nr(nr) : m_nr(nr)
{ {
const size_t N = cif::kSymopNrTableSize; const size_t N = kSymopNrTableSize;
int32_t L = 0, R = static_cast<int32_t>(N - 1); int32_t L = 0, R = static_cast<int32_t>(N - 1);
while (L <= R) while (L <= R)
{ {
int32_t i = (L + R) / 2; int32_t i = (L + R) / 2;
if (cif::kSymopNrTable[i].spacegroup() < m_nr) if (kSymopNrTable[i].spacegroup() < m_nr)
L = i + 1; L = i + 1;
else else
R = i - 1; R = i - 1;
...@@ -137,8 +170,8 @@ spacegroup::spacegroup(int nr) ...@@ -137,8 +170,8 @@ spacegroup::spacegroup(int nr)
m_index = L; m_index = L;
for (size_t i = L; i < N and cif::kSymopNrTable[i].spacegroup() == m_nr; ++i) for (size_t i = L; i < N and kSymopNrTable[i].spacegroup() == m_nr; ++i)
emplace_back(cif::kSymopNrTable[i].symop().data()); emplace_back(kSymopNrTable[i].symop().data());
} }
std::string spacegroup::get_name() const std::string spacegroup::get_name() const
...@@ -152,14 +185,60 @@ std::string spacegroup::get_name() const ...@@ -152,14 +185,60 @@ std::string spacegroup::get_name() const
throw std::runtime_error("Spacegroup has an invalid number: " + std::to_string(m_nr)); throw std::runtime_error("Spacegroup has an invalid number: " + std::to_string(m_nr));
} }
point spacegroup::operator()(const point &pt, const cell &c, sym_op symop) point offsetToOrigin(const cell &c, const point &p)
{
point d{};
while (p.m_x + d.m_x < (c.get_a() / 2))
d.m_x += c.get_a();
while (p.m_x + d.m_x > (c.get_a() / 2))
d.m_x -= c.get_a();
while (p.m_y + d.m_y < (c.get_b() / 2))
d.m_y += c.get_b();
while (p.m_y + d.m_y > (c.get_b() / 2))
d.m_y -= c.get_b();
while (p.m_z + d.m_z < (c.get_c() / 2))
d.m_z += c.get_c();
while (p.m_z + d.m_z > (c.get_c() / 2))
d.m_z -= c.get_c();
return d;
};
std::tuple<int,int,int> offsetToOriginInt(const cell &c, const point &p)
{
auto o = offsetToOrigin(c, p);
return {
std::rintf(o.m_x / c.get_a()),
std::rintf(o.m_y / c.get_b()),
std::rintf(o.m_z / c.get_c())
};
}
point spacegroup::operator()(const point &pt, const cell &c, sym_op symop) const
{ {
if (symop.m_nr < 1 or symop.m_nr > size()) if (symop.m_nr < 1 or symop.m_nr > size())
throw std::out_of_range("symmetry operator number out of range"); throw std::out_of_range("symmetry operator number out of range");
transformation t = at(symop.m_nr - 1); transformation t = at(symop.m_nr - 1);
return pt; t.m_translation.m_x += symop.m_ta - 5;
t.m_translation.m_y += symop.m_tb - 5;
t.m_translation.m_z += symop.m_tc - 5;
auto t_orth = orthogonal(t, c);
auto o = offsetToOrigin(c, pt);
transformation tlo(identity_matrix<float>(3), o);
auto itlo = inverse(tlo);
point result = pt + o;
result = t_orth(result);
result = itlo(result);
return result;
} }
// -------------------------------------------------------------------- // --------------------------------------------------------------------
...@@ -279,4 +358,84 @@ int get_space_group_number(std::string_view spacegroup, space_group_name type) ...@@ -279,4 +358,84 @@ int get_space_group_number(std::string_view spacegroup, space_group_name type)
return result; return result;
} }
int get_space_group_number(const datablock &db)
{
auto &_symmetry = db["symmetry"];
if (_symmetry.size() != 1)
throw std::runtime_error("Could not find a unique symmetry in this mmCIF file");
return _symmetry.front().get<int>("Int_Tables_number");
}
// --------------------------------------------------------------------
std::tuple<float,point,sym_op> closest_symmetry_copy(const spacegroup &sg, const cell &c, point a, point b)
{
if (c.get_a() == 0 or c.get_b() == 0 or c.get_c() == 0)
throw std::runtime_error("Invalid cell, contains a dimension that is zero");
point result_p;
float result_d = std::numeric_limits<float>::max();
sym_op result_s;
auto fa = fractional(a, c);
auto fb = fractional(b, c);
for (size_t i = 0; i < sg.size(); ++i)
{
sym_op s(i + 1);
auto &t = sg[i];
auto fsb = t(fb);
while (fsb.m_x - 0.5f > fa.m_x)
{
fsb.m_x -= 1;
s.m_ta -= 1;
}
while (fsb.m_x + 0.5f < fa.m_x)
{
fsb.m_x += 1;
s.m_ta += 1;
}
while (fsb.m_y - 0.5f > fa.m_y)
{
fsb.m_y -= 1;
s.m_tb -= 1;
}
while (fsb.m_y + 0.5f < fa.m_y)
{
fsb.m_y += 1;
s.m_tb += 1;
}
while (fsb.m_z - 0.5f > fa.m_z)
{
fsb.m_z -= 1;
s.m_tc -= 1;
}
while (fsb.m_z + 0.5f < fa.m_z)
{
fsb.m_z += 1;
s.m_tc += 1;
}
auto p = orthogonal(fsb, c);
auto dsq = distance_squared(a, p);
if (result_d > dsq)
{
result_d = dsq;
result_p = p;
result_s = s;
}
}
return { std::sqrt(result_d), result_p, result_s };
}
} // namespace cif } // namespace cif
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <cif++.hpp> #include <cif++.hpp>
namespace tt = boost::test_tools; namespace tt = boost::test_tools;
namespace utf = boost::unit_test;
std::filesystem::path gTestDir = std::filesystem::current_path(); // filled in first test std::filesystem::path gTestDir = std::filesystem::current_path(); // filled in first test
...@@ -288,7 +289,7 @@ BOOST_AUTO_TEST_CASE(symm_3) ...@@ -288,7 +289,7 @@ BOOST_AUTO_TEST_CASE(symm_3)
BOOST_TEST(sg.get_name() == "P 21 21 2"); BOOST_TEST(sg.get_name() == "P 21 21 2");
} }
BOOST_AUTO_TEST_CASE(symm_4) BOOST_AUTO_TEST_CASE(symm_4, *utf::tolerance(0.1f))
{ {
using namespace cif::literals; using namespace cif::literals;
...@@ -311,3 +312,51 @@ BOOST_AUTO_TEST_CASE(symm_4) ...@@ -311,3 +312,51 @@ BOOST_AUTO_TEST_CASE(symm_4)
BOOST_TEST(distance(a, sb2) == 7.42f); BOOST_TEST(distance(a, sb2) == 7.42f);
} }
BOOST_AUTO_TEST_CASE(symm_2bi3_1, *utf::tolerance(0.1f))
{
cif::file f(gTestDir / "2bi3.cif.gz");
auto &db = f.front();
cif::mm::structure s(db);
cif::spacegroup sg(db);
cif::cell c(db);
auto struct_conn = db["struct_conn"];
for (const auto &[
asym1, seqid1, authseqid1, atomid1, symm1,
asym2, seqid2, authseqid2, atomid2, symm2,
dist] : struct_conn.find<
std::string,int,std::string,std::string,std::string,
std::string,int,std::string,std::string,std::string,
float>(
cif::key("ptnr1_symmetry") != "1_555" or cif::key("ptnr2_symmetry") != "1_555",
"ptnr1_label_asym_id", "ptnr1_label_seq_id", "ptnr1_auth_seq_id", "ptnr1_label_atom_id", "ptnr1_symmetry",
"ptnr2_label_asym_id", "ptnr2_label_seq_id", "ptnr2_auth_seq_id", "ptnr2_label_atom_id", "ptnr2_symmetry",
"pdbx_dist_value"
))
{
auto &r1 = s.get_residue(asym1, seqid1, authseqid1);
auto &r2 = s.get_residue(asym2, seqid2, authseqid2);
auto a1 = r1.get_atom_by_atom_id(atomid1);
auto a2 = r2.get_atom_by_atom_id(atomid2);
auto sa1 = symmetry_copy(a1.get_location(), sg, c, cif::sym_op(symm1));
auto sa2 = symmetry_copy(a2.get_location(), sg, c, cif::sym_op(symm2));
BOOST_TEST(cif::distance(sa1, sa2) == dist);
auto pa1 = a1.get_location();
const auto &[d, p, so] = cif::closest_symmetry_copy(sg, c, pa1, a2.get_location());
BOOST_TEST(p.m_x == sa2.m_x);
BOOST_TEST(p.m_y == sa2.m_y);
BOOST_TEST(p.m_z == sa2.m_z);
BOOST_TEST(d == dist);
BOOST_TEST(so.string() == symm2);
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment