Source code for ase2sprkkr.asr.test.test_database_rmsd

import pytest
from ase import Atoms


@pytest.mark.ci
def test_database_rmsd_duplicates(duplicates_test_db):
    """Test that all rows of the duplicates database are matched with rmsd=0.

    Every row id must be present in ``rmsd_by_id``, must have been compared
    with every other row id, and each of those comparisons must give an rmsd
    of (approximately) zero.
    """
    from asr.database.rmsd import main

    row_ids = range(1, len(duplicates_test_db) + 1)
    rmsd_by_id = main('duplicates.db', 'duplicates-rmsd.db')['rmsd_by_id']

    # Every row must appear as a key of the result...
    assert set(rmsd_by_id.keys()) >= set(row_ids)
    for row_id in row_ids:
        partners = {other for other in row_ids if other != row_id}
        # ...must have been compared with every other row...
        assert set(rmsd_by_id[row_id].keys()) >= partners
        # ...and must be an exact duplicate of each of them.
        for partner in partners:
            assert rmsd_by_id[row_id][partner] == pytest.approx(0)
@pytest.mark.ci
def test_database_rmsd_duplicates_comparison_keys(duplicates_test_db):
    """Test that ``comparison_keys`` restricts which rows get rmsd data.

    With ``comparison_keys='magstate'`` the rows for which ``main`` reports
    rmsd values must be exactly ids 1, 3, 4, 5 and 6.
    """
    from asr.database.rmsd import main

    results = main('duplicates.db', 'duplicates-rmsd.db',
                   comparison_keys='magstate')
    rmsd_by_id = results['rmsd_by_id']
    # NOTE(review): presumably row 2 differs from the others in its
    # 'magstate' key and so has no partner to be compared with; confirm
    # against the duplicates_test_db fixture.
    assert set(rmsd_by_id.keys()) == {1, 3, 4, 5, 6}
@pytest.mark.ci
@pytest.mark.parametrize('angle', [30])
@pytest.mark.parametrize('vector', ['x', 'y', 'z'])
def test_database_rmsd_rotation(test_material, angle, vector):
    """Test that a rigid rotation (atoms and cell together) gives rmsd=0."""
    from asr.database.rmsd import get_rmsd

    rotated = test_material.copy()
    # Rotate the cell as well, so the structure is only re-oriented in space.
    rotated.rotate(angle, v=vector, rotate_cell=True)
    assert get_rmsd(test_material, rotated) == pytest.approx(0)
@pytest.mark.ci
def test_database_rmsd_none_handling(asr_tmpdir):
    """Test that ``main`` copes with ``get_rmsd`` returning None.

    A database with Si and a strongly rattled copy of it (``rattle(1.0)``)
    is processed with a very tight ``max_rmsd``. The test passes when
    ``main`` completes without raising.
    """
    from .materials import Si
    from ase.db import connect
    from asr.database.rmsd import main

    database = connect('duplicates_rattled.db')
    database.write(Si)

    strongly_rattled = Si.copy()
    strongly_rattled.rattle(1.0)
    database.write(strongly_rattled)

    main('duplicates_rattled.db', 'duplicates_rattled_rmsd.db',
         max_rmsd=0.001)
def rattle_atoms(atoms, scale=0.01, seed=42):
    """Displace every atom by a distance ``scale`` in a random direction.

    Directions are drawn from an isotropic normal distribution using a
    ``numpy.random.RandomState`` seeded with ``seed``, so the result is
    reproducible. ``atoms`` is modified in place and is also returned.

    Parameters
    ----------
    atoms : ase.Atoms
        Structure to rattle. Its positions are overwritten.
    scale : float
        Length of the displacement applied to each atom. ``0`` leaves the
        positions unchanged.
    seed : int
        Seed for the random number generator.

    Returns
    -------
    ase.Atoms
        The same ``atoms`` object.
    """
    import numpy as np

    rng = np.random.RandomState(seed)
    pos = atoms.get_positions()
    # Draw directions with unit spread. The spread cancels after
    # normalisation, and drawing with ``scale`` would give 0/0 (NaN
    # positions) for ``scale=0``.
    dir_av = rng.normal(size=pos.shape)
    dir_av /= np.linalg.norm(dir_av, axis=1)[:, None]
    atoms.set_positions(pos + dir_av * scale)
    return atoms
@pytest.mark.ci
def test_database_rmsd_rattled(test_material):
    """Test that a rattled supercell of the material has a non-zero rmsd."""
    import numpy as np
    from asr.database.rmsd import get_rmsd

    # Repeat three times along periodic directions, once along the others.
    repeat = np.where(test_material.get_pbc(), 3, 1)
    supercell = test_material.repeat(repeat)
    rattled_atoms = rattle_atoms(supercell.copy(), 0.01, seed=42)

    rmsd = get_rmsd(test_material, rattled_atoms)
    assert rmsd > 0.0, (supercell.get_scaled_positions()
                        - rattled_atoms.get_scaled_positions())
@pytest.mark.ci
@pytest.mark.parametrize('atom_index', [0, -1])
@pytest.mark.parametrize('axis', [0, 1, 2])
@pytest.mark.parametrize('displacement', [0.01, 0.02])
def test_database_rmsd_move_one_atom(test_material, atom_index, axis,
                                     displacement):
    """Test the rmsd obtained when one atom of a supercell is displaced.

    The material is doubled along its first periodic direction. The atom at
    ``atom_index`` is then translated by ``displacement`` along Cartesian
    ``axis``, and the rmsd against the unmodified material is compared with
    the expected value computed below.
    """
    import numpy as np
    from asr.database.rmsd import get_rmsd

    pbc_c = test_material.get_pbc()
    repeat = np.array([1, 1, 1], int)
    # Double the cell along the first periodic direction.
    # NOTE(review): this raises IndexError for a fully non-periodic material.
    repeat[np.argwhere(pbc_c)[0, 0]] = 2
    atoms = test_material.repeat(repeat)

    # Move a single atom; every other row of the translation array is zero.
    translations_av = np.zeros((len(atoms), 3), float)
    translations_av[atom_index, axis] = displacement
    atoms.translate(translations_av)
    rmsd = get_rmsd(test_material, atoms)

    # Logic of the formula below:
    # The reference structure is displaced according to the average
    # displacement of all atoms, hence the second term.
    disp = displacement - displacement / len(atoms)
    assert rmsd == pytest.approx(disp)
@pytest.mark.ci
@pytest.mark.parametrize('atoms1,atoms2', [
    (
        Atoms(symbols='Co2S2',
              pbc=[True, True, False],
              cell=[[3.5790788191969725, -1.1842760125086163e-20, 0.0],
                    [-1.7895394075540594, 3.10048672285293, 0.0],
                    [2.3583795244967227e-18, 0.0, 18.85580293064]],
              scaled_positions=[[0, 0, 0.56],
                                [1 / 3, 2 / 3, 0.44],
                                [1 / 3, 2 / 3, 0.40],
                                [0, 0, 0.60]]),
        Atoms(symbols='Co2S2',
              pbc=[True, True, False],
              cell=[[3.5790788191969725, -1.1842760125086163e-20, 0.0],
                    [-1.7895394075540594, 3.10048672285293, 0.0],
                    [2.3583795244967227e-18, 0.0, 18.85580293064]],
              scaled_positions=[[0, 0, 0.56],
                                [0, 0, 0.44],
                                [1 / 3, 2 / 3, 0.40],
                                [2 / 3, 1 / 3, 0.60]]),
    ),
])
def test_database_rmsd_not_equal(atoms1, atoms2):
    """Test explicit pairs of different structures that previously posed a problem.

    The two structures in each pair must not be identified as equal. Either
    ``get_rmsd`` returns None, or it reports an rmsd that is not below 0.5.
    """
    from asr.database.rmsd import get_rmsd

    rmsd = get_rmsd(atoms1, atoms2)
    # get_rmsd can return None (presumably when no match is found), and
    # ``None < 0.5`` raises TypeError, so handle it explicitly as "not equal".
    # ``not rmsd < 0.5`` is kept (rather than ``rmsd >= 0.5``) to preserve
    # the original comparison semantics.
    assert rmsd is None or not rmsd < 0.5