wmaee.codes.vasp.vasp_ml

Collection of helper functions for handling ML functionality of VASP.

  1"""
  2Collection of helper functions for handling ML functionality of VASP.
  3"""
  4
  5from typing import Callable
  6from typing import List, Union
  7
  8def ML_ABN_concat(files: List[str], output: str = 'ML_AB', overwrite: bool = False, verbose: bool = True) -> None:
  9    """
 10    Concatenates a list of `ML_AB` or `ML_ABN` files. This function merges headers and structures
 11    from multiple files into a single output file.
 12
 13    Parameters
 14    ----------
 15    files : List[str]
 16        List of paths to input files that should be concatenated.
 17    output : str, optional
 18        Path for the output file, by default 'ML_AB'.
 19    overwrite : bool, optional
 20        Whether to overwrite the output file if it already exists, by default False.
 21    verbose : bool, optional
 22        If True, prints details of the merging process, by default True.
 23
 24    Raises
 25    ------
 26    FileExistsError
 27        If the output file exists and overwrite is set to False.
 28
 29    Notes
 30    -----
 31    This function currently assumes that all input files contain the same chemistry.
 32    """
 33    
 34    from os.path import exists
 35    from os import remove, rename
 36
 37    if exists(output) and not overwrite:
 38        raise FileExistsError(f'Output file {output} exists in the current folder and overwrite is not permitted.')
 39    else:
 40        out = open(output+'.tmp_wmaee_tmp', 'w')
 41
 42    s = 0
 43    sp = list()
 44    max_types = 0
 45    max_atoms = 0
 46    max_species = 0
 47    masses = dict()
 48    for i, f in enumerate(files):
 49        if verbose:
 50            print(f'reading structures from {f}')
 51        with open(f, 'r') as src:
 52            # deal with header: collect merged information
 53            l = src.readline();
 54            while l.strip().split()[0] != 'Configuration':                
 55                if 'The atom types' in l.strip():
 56                    l = src.readline();
 57                    l = src.readline();
 58                    this_sp = l.strip().split()
 59                    sp += this_sp
 60                    sp = list(set(sp))
 61                    if verbose:
 62                        print(f'  found species: {sp}')
 63                if 'number of atom type' in l.strip():
 64                    l = src.readline();
 65                    l = src.readline();
 66                    max_at = int(l.strip())
 67                    if verbose:
 68                        print(f'  maximum number of atom types: {max_at}')
 69                    max_types = max(max_types, max_at)                        
 70                if 'atoms per system' in l.strip():
 71                    l = src.readline();
 72                    l = src.readline();
 73                    max_at = int(l.strip())
 74                    if verbose:
 75                        print(f'  maximum number of atoms: {max_at}')
 76                    max_atoms = max(max_atoms, max_at)
 77                if 'atoms per atom type' in l.strip():
 78                    l = src.readline();
 79                    l = src.readline();
 80                    max_at = int(l.strip())
 81                    if verbose:
 82                        print(f'  maximum number of atoms per type: {max_at}')
 83                    max_species = max(max_species, max_at)
 84                if 'Atomic mass' in l.strip():
 85                    l = src.readline();
 86                    m = ''
 87                    l = src.readline();
 88                    while '*********************' not in l.strip():
 89                        m += l
 90                        l = src.readline();
 91                    for X, M in zip(this_sp, m.strip().split()):
 92                        masses[X] = float(M)
 93                        if verbose:
 94                            print(f'  mass of {X}: {float(M)} a.u.')              
 95                l = src.readline();
 96            # done with headers, now go configuration by configuration            
 97            out.writelines("**************************************************\n")
 98            while l:
 99                s += 1
100                out.writelines(f"     Configuration num.{s:7d}\n")
101                if verbose:
102                    print(f"     Configuration num.{s:7d}")
103                l = src.readline()
104                while l and l.strip().split()[0] != 'Configuration':
105                    out.writelines(l)
106                    l = src.readline()
107    out.close()
108    
109    sp = list(set(sp))
110
111    # write header    
112    with open(output, 'w') as out:        
113        _SEP_star  = "**************************************************\n"
114        _SEP_minus = "--------------------------------------------------\n"
115        _SEP_equal = "==================================================\n"
116        out.writelines(' 1.0 Version\n'+_SEP_star)
117        out.writelines('     The number of configurations\n'+_SEP_minus+f'{s:11d}\n'+_SEP_star)
118        out.writelines('     The maximum number of atom type\n'+_SEP_minus+f'{len(sp):8d}\n'+_SEP_star)
119        out.writelines('     The atom types in the data file\n'+_SEP_minus+'     ')
120        for i, el in enumerate(sp):
121            out.writelines(el.ljust(3))
122            if (i%3 == 2) and i<len(sp)-1:
123                out.writelines('\n     ')
124        out.writelines('\n'+_SEP_star)
125        out.writelines('     The maximum number of atoms per system\n'+_SEP_minus+f'{max_atoms:8d}\n'+_SEP_star)
126        out.writelines('     The maximum number of atoms per atom type\n'+_SEP_minus+f'{max_species:8d}\n'+_SEP_star)
127        out.writelines('     Reference atomic energy (eV)\n'+_SEP_minus)
128        for i, _ in enumerate(sp):
129            out.writelines('   '+f'{0:18.16f}'+'     ') # dummy energy
130            if (i%3 == 2) and i<len(sp)-1:
131                out.writelines('\n')
132        out.writelines('\n'+_SEP_star)
133        out.writelines('     Atomic mass\n'+_SEP_minus)        
134        for i, X in enumerate(sp):
135            out.writelines('   '+f'{masses[X]:18.16f}'+'     ')
136            if (i%3 == 2) and i<len(sp)-1:
137                out.writelines('\n')
138        out.writelines('\n'+_SEP_star)
139        out.writelines('     The numbers of basis sets per atom type\n'+_SEP_minus+'    ')
140        for i, _ in enumerate(sp):
141            out.writelines('     1')
142            if (i%3 == 2) and i<len(sp)-1:
143                out.writelines('\n')
144        out.writelines('\n')
145        for i, el in enumerate(sp):
146            out.writelines(_SEP_star+f'     Basis set for {el}\n'+_SEP_minus+'          1      1\n')
147        
148        with open(output+'.tmp_wmaee_tmp', 'r') as src:
149            l = src.readline()
150            while l:
151                out.writelines(l)
152                l = src.readline()
153              
154    remove(output+'.tmp_wmaee_tmp')
155    
156    
157    
def _open_text(filename: str):
    """Open ``filename`` for text reading, decompressing it transparently if needed."""
    try:
        from monty.io import zopen
    except ImportError:
        # Fallback when monty is not installed: choose a stdlib opener by extension.
        import bz2
        import gzip
        import lzma
        from os.path import splitext
        openers = {'.gz': gzip.open, '.bz2': bz2.open, '.xz': lzma.open, '.lzma': lzma.open}
        return openers.get(splitext(filename)[1].lower(), open)(filename, mode='rt')
    return zopen(filename, mode='rt')


def _read_table_pattern(filename: str, header_pattern: str, row_pattern: str,
        footer_pattern: str, postprocess: Callable = str, first_one_only: bool = False,
    ) -> list:
    """
    Parse table-like data from a (possibly compressed) text file.

    A table comprises three parts: header, main body, and footer. Every line of
    the main body that matches ``row_pattern`` is returned as one row.

    Parameters
    ----------
    filename : str
        The name of the file to read the table from. Compressed files are
        opened with ``monty.io.zopen`` when available, otherwise ``.gz``,
        ``.bz2``, ``.xz`` and ``.lzma`` files are handled by the standard library.
    header_pattern : str
        The regular expression pattern that matches the table header.
        This pattern should match all the text immediately before the main body of the table.
        For multiple sections table, match the text until the section of interest.
        MULTILINE and DOTALL options are enforced, so the "." meta-character will also match "\\n" in this section.
    row_pattern : str
        The regular expression that matches a single line in the table. Capture interested fields using regular expression groups.
    footer_pattern : str
        The regular expression that matches the end of the table, e.g., a long dash line.
    postprocess : Callable, optional
        A post-processing function applied to every captured field. Defaults to str, i.e., no change.
    first_one_only : bool, optional
        Stop after the first occurrence of the table. The result is still a
        list of tables (holding at most one table), so callers index it with ``[0]``.

    Returns
    -------
    list
        A list of tables. A table is a list of rows. A row is either a list of
        attribute values (if capturing groups are defined without names in
        row_pattern), or a dictionary (if named capturing groups are defined by
        row_pattern). An empty list is returned if no table matches.

    Notes
    -----
    This function is adapted from pymatgen.io.vasp.outputs.py.
    """
    import re

    with _open_text(filename) as file:
        text = file.read()
    table_pattern = re.compile(
        header_pattern + r"\s*^(?P<table_body>(?:\s+" + row_pattern + r")+)\s+" + footer_pattern,
        re.MULTILINE | re.DOTALL,
    )
    row_re = re.compile(row_pattern)

    tables = []
    for match in table_pattern.finditer(text):
        rows = []
        for line in match.group("table_body").split("\n"):
            row_match = row_re.search(line)
            if not row_match:
                # body lines (e.g. blank ones) that do not match the row pattern are ignored
                continue
            named = row_match.groupdict()
            if named:
                rows.append({k: postprocess(v) for k, v in named.items()})
            else:
                rows.append([postprocess(v) for v in row_match.groups()])
        tables.append(rows)
        if first_one_only:
            break
    return tables
223    
224    
225
def generate_ML_AB(input: str = 'OUTCAR', input_type: str = 'OUTCAR', output: str = 'ML_AB',
                   overwrite: bool = False, verbose: bool = True, system_name: str = 'parsed AIMD') -> None:
    """
    Generate an ML_AB formatted file from the output of a VASP calculation.

    For every ionic step, the lattice vectors, atomic positions, total (free)
    energy, forces and stress are parsed from an OUTCAR or vasprun.xml file
    and written, after a header describing the chemistry, in the ML_AB format.

    Parameters
    ----------
    input : str, optional
        Path to the input file, by default 'OUTCAR'.
    input_type : str, optional
        Type of the input file, either 'OUTCAR' or 'vasprun', by default 'OUTCAR'.
    output : str, optional
        Path to the output file, by default 'ML_AB'.
    overwrite : bool, optional
        Whether an existing output file may be overwritten, by default False.
    verbose : bool, optional
        Flag to print detailed output, by default True.
    system_name : str, optional
        Name written as "System name" of every configuration, by default 'parsed AIMD'.

    Raises
    ------
    ValueError
        If input_type is not 'OUTCAR' or 'vasprun' (a subclass of Exception,
        so handlers written for the previously raised Exception still work).
    FileExistsError
        If the input file does not exist (kept for backward compatibility),
        or if the output file exists and overwrite is False.

    Notes
    -----
    Only ionic steps for which lattice, positions, energy, forces and stress
    were all found are written. Reference atomic energies are written as zero
    and each atom type gets a single placeholder basis-set entry. Stresses are
    written as XX YY ZZ / XY YZ ZX for both input types.
    """
    from os.path import exists
    import numpy as np

    if input_type not in ('OUTCAR', 'vasprun'):
        raise ValueError('Unknown input file type, allowed values are `OUTCAR` or `vasprun`')
    if not exists(input):
        raise FileExistsError(f'Input file {input} doesn\'t exist in the current folder.')
    if exists(output) and not overwrite:
        raise FileExistsError(f'Output file {output} exists in the current folder and overwrite is not permitted.')

    # one <v> row of three floats inside a vasprun.xml <varray>
    xml_vec3 = r"\s*<v>\s*([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s<\/v>\s*"
    xml_varray_end = r"\s*<\/varray>"

    # --- species, ions per species and masses ---------------------------------
    if input_type == 'vasprun':
        chemistry = _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<field type=\"string\">pseudopotential<\/field>\s*<set>",
            row_pattern=r"\s*<rc><c>\s*(\d*)<\/c><c>\s*(\w*)\s*<\/c><c>\s*(\d*.\d*)<\/c><c>.*",
            footer_pattern=r"\s*<\/set>",
            first_one_only=True,
        )[0]
        sp = [s for _, s, _ in chemistry]
        num_sp = [int(n) for n, _, _ in chemistry]
        masses = [float(m) for _, _, m in chemistry]
    else:
        from monty.re import regrep
        sp = [x[0][0] for x in regrep(input, patterns=dict(species=r'TITEL\s+=\s\w+\s(\w+)'))['species']]
        num_sp = regrep(
            input,
            patterns=dict(num_species=r'ions per type\s+=\s+' + r'(\d+)\s*' * len(sp)),
            terminate_on_match=True,
            postprocess=int,
        )['num_species'][0][0]
        masses = regrep(
            input,
            patterns=dict(masses=r'POMASS\s+=\s+' + r'(\d+.\d+)\s*' * len(sp)),
            terminate_on_match=True,
            postprocess=float,
        )['masses'][0][0]
    if verbose:
        print('found species: ', sp)
        print('num species: ', num_sp)
        print('masses: ', masses)

    # --- lattice vectors --------------------------------------------------------
    if input_type == 'OUTCAR':
        lattices = _read_table_pattern(
            filename=input,
            header_pattern=r"\sdirect lattice vectors\s+reciprocal lattice vectors",
            row_pattern=r"\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+[+-]?\d+\.\d+\s+[+-]?\d+\.\d+\s+[+-]?\d+\.\d+",
            footer_pattern=r"\n\s+length of vectors",
            postprocess=float,
        )[2:]  # the first two blocks are printed before the first ionic step
    else:
        # NOTE(review): skips the first two "basis" blocks, presumably written
        # before the first ionic step -- confirm against a vasprun.xml
        lattices = _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<varray name=\"basis\"\s*>",
            row_pattern=xml_vec3,
            footer_pattern=xml_varray_end,
            postprocess=float,
        )[2:]
    if lattices and len(lattices[-1]) != 3:
        lattices = lattices[:-1]  # drop an incomplete trailing lattice
    if verbose:
        print('found lattices: ', len(lattices))

    # --- energies and stresses --------------------------------------------------
    if input_type == 'OUTCAR':
        found = regrep(
            input,
            patterns=dict(
                energies=r'free  energy\s+TOTEN\s+=\s+([+-]?\d+.\d+)\seV',
                stresses=r'in kB' + r'\s+([+-]?\d+.\d+)' * 6,
            ),
            terminate_on_match=False,
            postprocess=float,
        )
        energies = [e[0][0] for e in found['energies']]
        # OUTCAR "in kB" line is ordered XX YY ZZ XY YZ ZX
        stresses = [s[0] for s in found['stresses']]
    else:
        energies = [e[0][0] for e in _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<energy>",
            row_pattern=r"\s*<i name=\"\w*\">\s*([+-]?\d+\.\d+)\s+<\/i>\s*",
            footer_pattern=r"\s*<i name=\"kinetic\">",
            postprocess=float,
        )]
        # full 3x3 tensor -> XX YY ZZ XY YZ ZX, matching the OUTCAR ordering
        stresses = [[s[0][0], s[1][1], s[2][2], s[0][1], s[1][2], s[2][0]] for s in _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<varray name=\"stress\"\s*>",
            row_pattern=xml_vec3,
            footer_pattern=xml_varray_end,
            postprocess=float,
        )]
    if verbose:
        print('found energies: ', len(energies))
        print('found stresses: ', len(stresses))

    # --- positions and forces ---------------------------------------------------
    if input_type == 'OUTCAR':
        positions_and_forces = _read_table_pattern(
            filename=input,
            header_pattern=r"\sPOSITION\s+TOTAL-FORCE \(eV/Angst\)\n\s-+",
            row_pattern=r"\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)",
            footer_pattern=r"\s--+",
            postprocess=float,
        )
        if positions_and_forces and len(positions_and_forces[-1]) != sum(num_sp):
            # the last block is incomplete (e.g. an interrupted run): drop it
            positions_and_forces = positions_and_forces[:-1]
        positions = [[row[:3] for row in step] for step in positions_and_forces]
        forces = [[row[3:] for row in step] for step in positions_and_forces]
    else:
        positions = _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<varray name=\"positions\"\s*>",
            row_pattern=xml_vec3,
            footer_pattern=xml_varray_end,
            postprocess=float,
        )[2:]  # same skip as for the lattices above
        forces = _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<varray name=\"forces\"\s*>",
            row_pattern=xml_vec3,
            footer_pattern=xml_varray_end,
            postprocess=float,
        )
    if verbose:
        print('found positions: ', len(positions))
        print('found forces: ', len(forces))

    # Only keep ionic steps for which every quantity is available.
    take_only = min(len(lattices), len(positions), len(forces), len(energies), len(stresses))

    _SEP_star = "**************************************************\n"
    _SEP_minus = "--------------------------------------------------\n"
    _SEP_equal = "==================================================\n"

    def vector_line(values) -> str:
        return '  ' + '  '.join(f'{x:18.16f}' for x in values) + '\n'

    def rows_of_three(items, newline: str) -> str:
        # The ML_AB header lists at most three entries per line.
        return newline.join(''.join(items[k:k + 3]) for k in range(0, len(items), 3))

    n_atoms = sum(num_sp)
    with open(output, 'w') as out:
        # --- header -------------------------------------------------------------
        out.write(' 1.0 Version\n' + _SEP_star)
        out.write('     The number of configurations\n' + _SEP_minus + f'{take_only:11d}\n' + _SEP_star)
        out.write('     The maximum number of atom type\n' + _SEP_minus + f'{len(num_sp):8d}\n' + _SEP_star)
        out.write('     The atom types in the data file\n' + _SEP_minus + '     '
                  + rows_of_three([el.ljust(3) for el in sp], '\n     ') + '\n' + _SEP_star)
        out.write('     The maximum number of atoms per system\n' + _SEP_minus + f'{n_atoms:8d}\n' + _SEP_star)
        out.write('     The maximum number of atoms per atom type\n' + _SEP_minus + f'{max(num_sp):8d}\n' + _SEP_star)
        out.write('     Reference atomic energy (eV)\n' + _SEP_minus
                  + rows_of_three(['   ' + f'{0:18.16f}' + '     ' for _ in sp], '\n') + '\n' + _SEP_star)  # dummy energies
        out.write('     Atomic mass\n' + _SEP_minus
                  + rows_of_three(['   ' + f'{m:18.16f}' + '     ' for m in masses], '\n') + '\n' + _SEP_star)
        out.write('     The numbers of basis sets per atom type\n' + _SEP_minus + '    '
                  + rows_of_three(['     1' for _ in sp], '\n') + '\n')
        for el in sp:
            out.write(_SEP_star + f'     Basis set for {el}\n' + _SEP_minus + '          1      1\n')

        # --- one section per ionic step ------------------------------------------
        for i in range(take_only):
            lattice = lattices[i]
            out.write(_SEP_star)
            out.write(f'     Configuration num.{i+1:7d}\n' + _SEP_equal)
            out.write('     System name\n' + _SEP_minus + '     ' + system_name + '\n' + _SEP_equal)
            out.write('     The number of atom types\n' + _SEP_minus + f'{len(num_sp):8d}\n' + _SEP_equal)
            out.write('     The number of atoms\n' + _SEP_minus + f'{n_atoms:8d}\n' + _SEP_star)
            out.write('     Atom types and atom numbers\n' + _SEP_minus)
            out.writelines('     ' + el.ljust(3) + f'{n:6d}\n' for el, n in zip(sp, num_sp))
            out.write(_SEP_equal)
            out.write('     Primitive lattice vectors (ang.)\n' + _SEP_minus)
            out.writelines(vector_line(v) for v in lattice)
            out.write(_SEP_equal)
            out.write('     Atomic positions (ang.)\n' + _SEP_minus)
            if input_type == 'OUTCAR':
                out.writelines(vector_line(p) for p in positions[i])
            else:
                # vasprun.xml positions are fractional: convert to Cartesian
                cell = np.array(lattice)
                out.writelines(vector_line(np.matmul(p, cell)) for p in positions[i])
            out.write(_SEP_equal)
            out.write('     Total energy (eV)\n' + _SEP_minus + f'  {energies[i]}\n' + _SEP_equal)
            out.write('     Forces (eV ang.^-1)\n' + _SEP_minus)
            out.writelines(vector_line(f) for f in forces[i])
            out.write(_SEP_equal)
            out.write('     Stress (kbar)\n' + _SEP_minus + '     XX YY ZZ\n' + _SEP_minus)
            out.write(vector_line(stresses[i][:3]) + _SEP_minus)
            out.write('     XY YZ ZX\n' + _SEP_minus)
            out.write(vector_line(stresses[i][3:]))
def ML_ABN_concat(files: List[str], output: str = 'ML_AB', overwrite: bool = False, verbose: bool = True) -> None:
  9def ML_ABN_concat(files: List[str], output: str = 'ML_AB', overwrite: bool = False, verbose: bool = True) -> None:
 10    """
 11    Concatenates a list of `ML_AB` or `ML_ABN` files. This function merges headers and structures
 12    from multiple files into a single output file.
 13
 14    Parameters
 15    ----------
 16    files : List[str]
 17        List of paths to input files that should be concatenated.
 18    output : str, optional
 19        Path for the output file, by default 'ML_AB'.
 20    overwrite : bool, optional
 21        Whether to overwrite the output file if it already exists, by default False.
 22    verbose : bool, optional
 23        If True, prints details of the merging process, by default True.
 24
 25    Raises
 26    ------
 27    FileExistsError
 28        If the output file exists and overwrite is set to False.
 29
 30    Notes
 31    -----
 32    This function currently assumes that all input files contain the same chemistry.
 33    """
 34    
 35    from os.path import exists
 36    from os import remove, rename
 37
 38    if exists(output) and not overwrite:
 39        raise FileExistsError(f'Output file {output} exists in the current folder and overwrite is not permitted.')
 40    else:
 41        out = open(output+'.tmp_wmaee_tmp', 'w')
 42
 43    s = 0
 44    sp = list()
 45    max_types = 0
 46    max_atoms = 0
 47    max_species = 0
 48    masses = dict()
 49    for i, f in enumerate(files):
 50        if verbose:
 51            print(f'reading structures from {f}')
 52        with open(f, 'r') as src:
 53            # deal with header: collect merged information
 54            l = src.readline();
 55            while l.strip().split()[0] != 'Configuration':                
 56                if 'The atom types' in l.strip():
 57                    l = src.readline();
 58                    l = src.readline();
 59                    this_sp = l.strip().split()
 60                    sp += this_sp
 61                    sp = list(set(sp))
 62                    if verbose:
 63                        print(f'  found species: {sp}')
 64                if 'number of atom type' in l.strip():
 65                    l = src.readline();
 66                    l = src.readline();
 67                    max_at = int(l.strip())
 68                    if verbose:
 69                        print(f'  maximum number of atom types: {max_at}')
 70                    max_types = max(max_types, max_at)                        
 71                if 'atoms per system' in l.strip():
 72                    l = src.readline();
 73                    l = src.readline();
 74                    max_at = int(l.strip())
 75                    if verbose:
 76                        print(f'  maximum number of atoms: {max_at}')
 77                    max_atoms = max(max_atoms, max_at)
 78                if 'atoms per atom type' in l.strip():
 79                    l = src.readline();
 80                    l = src.readline();
 81                    max_at = int(l.strip())
 82                    if verbose:
 83                        print(f'  maximum number of atoms per type: {max_at}')
 84                    max_species = max(max_species, max_at)
 85                if 'Atomic mass' in l.strip():
 86                    l = src.readline();
 87                    m = ''
 88                    l = src.readline();
 89                    while '*********************' not in l.strip():
 90                        m += l
 91                        l = src.readline();
 92                    for X, M in zip(this_sp, m.strip().split()):
 93                        masses[X] = float(M)
 94                        if verbose:
 95                            print(f'  mass of {X}: {float(M)} a.u.')              
 96                l = src.readline();
 97            # done with headers, now go configuration by configuration            
 98            out.writelines("**************************************************\n")
 99            while l:
100                s += 1
101                out.writelines(f"     Configuration num.{s:7d}\n")
102                if verbose:
103                    print(f"     Configuration num.{s:7d}")
104                l = src.readline()
105                while l and l.strip().split()[0] != 'Configuration':
106                    out.writelines(l)
107                    l = src.readline()
108    out.close()
109    
110    sp = list(set(sp))
111
112    # write header    
113    with open(output, 'w') as out:        
114        _SEP_star  = "**************************************************\n"
115        _SEP_minus = "--------------------------------------------------\n"
116        _SEP_equal = "==================================================\n"
117        out.writelines(' 1.0 Version\n'+_SEP_star)
118        out.writelines('     The number of configurations\n'+_SEP_minus+f'{s:11d}\n'+_SEP_star)
119        out.writelines('     The maximum number of atom type\n'+_SEP_minus+f'{len(sp):8d}\n'+_SEP_star)
120        out.writelines('     The atom types in the data file\n'+_SEP_minus+'     ')
121        for i, el in enumerate(sp):
122            out.writelines(el.ljust(3))
123            if (i%3 == 2) and i<len(sp)-1:
124                out.writelines('\n     ')
125        out.writelines('\n'+_SEP_star)
126        out.writelines('     The maximum number of atoms per system\n'+_SEP_minus+f'{max_atoms:8d}\n'+_SEP_star)
127        out.writelines('     The maximum number of atoms per atom type\n'+_SEP_minus+f'{max_species:8d}\n'+_SEP_star)
128        out.writelines('     Reference atomic energy (eV)\n'+_SEP_minus)
129        for i, _ in enumerate(sp):
130            out.writelines('   '+f'{0:18.16f}'+'     ') # dummy energy
131            if (i%3 == 2) and i<len(sp)-1:
132                out.writelines('\n')
133        out.writelines('\n'+_SEP_star)
134        out.writelines('     Atomic mass\n'+_SEP_minus)        
135        for i, X in enumerate(sp):
136            out.writelines('   '+f'{masses[X]:18.16f}'+'     ')
137            if (i%3 == 2) and i<len(sp)-1:
138                out.writelines('\n')
139        out.writelines('\n'+_SEP_star)
140        out.writelines('     The numbers of basis sets per atom type\n'+_SEP_minus+'    ')
141        for i, _ in enumerate(sp):
142            out.writelines('     1')
143            if (i%3 == 2) and i<len(sp)-1:
144                out.writelines('\n')
145        out.writelines('\n')
146        for i, el in enumerate(sp):
147            out.writelines(_SEP_star+f'     Basis set for {el}\n'+_SEP_minus+'          1      1\n')
148        
149        with open(output+'.tmp_wmaee_tmp', 'r') as src:
150            l = src.readline()
151            while l:
152                out.writelines(l)
153                l = src.readline()
154              
155    remove(output+'.tmp_wmaee_tmp')

Concatenates a list of ML_AB or ML_ABN files. This function merges headers and structures from multiple files into a single output file.

Parameters
  • files (List[str]): List of paths to input files that should be concatenated.
  • output (str, optional): Path for the output file, by default 'ML_AB'.
  • overwrite (bool, optional): Whether to overwrite the output file if it already exists, by default False.
  • verbose (bool, optional): If True, prints details of the merging process, by default True.
Raises
  • FileExistsError: If the output file exists and overwrite is set to False.
Notes

This function currently assumes that all input files contain the same chemistry.

def generate_ML_AB( input: str = 'OUTCAR', input_type: str = 'OUTCAR', output: str = 'ML_AB', overwrite: bool = False, verbose: bool = True, system_name: str = 'parsed AIMD') -> None:
def generate_ML_AB(input: str = 'OUTCAR', input_type: str = 'OUTCAR', output: str = 'ML_AB',
                   overwrite: bool = False, verbose: bool = True, system_name: str = 'parsed AIMD') -> None:
    """
    Generate an ML_AB formatted training-set file from VASP calculation outputs.

    Every ionic step found in `input` (lattice vectors, Cartesian positions,
    free energy, forces and stress tensor) is written as one configuration.
    Only steps for which all of these quantities could be parsed are kept, so
    an incomplete last step of an interrupted run is dropped.

    Parameters
    ----------
    input : str, optional
        Path to the input file, by default 'OUTCAR'
    input_type : str, optional
        Type of the input file, either 'OUTCAR' or 'vasprun', by default 'OUTCAR'
    output : str, optional
        Path to the output file, by default 'ML_AB'
    overwrite : bool, optional
        Whether to overwrite `output` if it already exists, by default False
    verbose : bool, optional
        Flag to print detailed output, by default True
    system_name : str, optional
        Name of the system written for every configuration, by default 'parsed AIMD'

    Raises
    ------
    ValueError
        If input_type is not 'OUTCAR' or 'vasprun'. ValueError derives from
        Exception, so handlers written for the previous generic Exception
        still catch it.
    FileExistsError
        If the input file does not exist (this exception type is kept for
        backward compatibility), or if the output file already exists and
        `overwrite` is False.

    Notes
    -----
    Reference atomic energies are written as dummy zeros and every atom type
    gets a single dummy basis-set entry.
    """
    from os.path import exists

    # Validate the arguments before importing the heavier dependencies.
    if input_type not in ('OUTCAR', 'vasprun'):
        raise ValueError('Unknown input file type, allowed values are `OUTCAR` or `vasprun`')
    if not exists(input):
        raise FileExistsError(f'Input file {input} doesn\'t exist in the current folder.')
    # Honour `overwrite` the same way ML_ABN_concat does (it used to be ignored).
    if exists(output) and not overwrite:
        raise FileExistsError(
            f'Output file {output} exists in the current folder and overwrite is not permitted.')

    from monty.re import regrep
    import numpy as np

    # One <v>x y z</v> row of a vasprun.xml <varray>.
    vec_row = r"\s*<v>\s*([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s<\/v>\s*"

    # --- chemistry: species, number of atoms per species, atomic masses ---
    if input_type == 'vasprun':
        chemistry = _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<field type=\"string\">pseudopotential<\/field>\s*<set>",
            row_pattern=r"\s*<rc><c>\s*(\d*)<\/c><c>\s*(\w*)\s*<\/c><c>\s*(\d*.\d*)<\/c><c>.*",
            footer_pattern=r"\s*<\/set>",
            first_one_only=True
        )[0]
        sp = [s for _, s, _ in chemistry]
        num_sp = [int(n) for n, _, _ in chemistry]
        masses = [float(m) for _, _, m in chemistry]
    else:
        # The species label is the second word after "TITEL =".
        sp = [x[0][0] for x in regrep(input, patterns=dict(species=r'TITEL\s+=\s\w+\s(\w+)'))['species']]
        num_sp = regrep(
            input,
            patterns=dict(num_species=r'ions per type\s+=\s+' + r'(\d+)\s*' * len(sp)),
            terminate_on_match=True,
            postprocess=int,
        )['num_species'][0][0]
        masses = regrep(
            input,
            patterns=dict(masses=r'POMASS\s+=\s+' + r'(\d+\.\d+)\s*' * len(sp)),
            terminate_on_match=True,
            postprocess=float,
        )['masses'][0][0]
    if verbose:
        print('found species: ', sp)
        print('num species: ', num_sp)
        print('masses: ', masses)

    # --- lattice vectors (the first two tables are printed at initialization) ---
    if input_type == 'OUTCAR':
        lattices = _read_table_pattern(
            filename=input,
            header_pattern=r"\sdirect lattice vectors\s+reciprocal lattice vectors",
            row_pattern=r"\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+[+-]?\d+\.\d+\s+[+-]?\d+\.\d+\s+[+-]?\d+\.\d+",
            footer_pattern=r"\n\s+length of vectors",
            postprocess=float
        )[2:]
    else:
        lattices = _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<varray name=\"basis\"\s*>",
            row_pattern=vec_row,
            footer_pattern=r"\s*<\/varray>",
            postprocess=float
        )[2:]
    if lattices and len(lattices[-1]) != 3:
        lattices = lattices[:-1]  # remove an incomplete trailing lattice
    if verbose:
        print('found lattices: ', len(lattices))

    # --- energies and stresses; stresses are stored as XX YY ZZ XY YZ ZX ---
    if input_type == 'OUTCAR':
        energies_stresses = regrep(
            input,
            patterns=dict(
                energies=r'free  energy\s+TOTEN\s+=\s+([+-]?\d+\.\d+)\seV',
                stresses=r'in kB' + r'\s+([+-]?\d+\.\d+)' * 6,
            ),
            terminate_on_match=False,
            postprocess=float,
        )
        energies = [e[0][0] for e in energies_stresses['energies']]
        # the OUTCAR "in kB" line is already ordered XX YY ZZ XY YZ ZX
        stresses = [s[0] for s in energies_stresses['stresses']]
    else:
        energies = [table[0][0] for table in _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<energy>",
            row_pattern=r"\s*<i name=\"\w*\">\s*([+-]?\d+\.\d+)\s+<\/i>\s*",
            footer_pattern=r"\s*<i name=\"kinetic\">",
            postprocess=float)]
        # Flatten each symmetric 3x3 tensor into XX YY ZZ XY YZ ZX; the previous
        # [xy, xz, yz] ordering swapped the YZ and ZX components.
        stresses = [[s[0][0], s[1][1], s[2][2], s[0][1], s[1][2], s[0][2]]
                    for s in _read_table_pattern(
                        filename=input,
                        header_pattern=r"\s*<varray name=\"stress\"\s*>",
                        row_pattern=vec_row,
                        footer_pattern=r"\s*<\/varray>",
                        postprocess=float)]
    if verbose:
        print('found energies: ', len(energies))
        print('found stresses: ', len(stresses))

    # --- positions and forces ---
    if input_type == 'OUTCAR':
        positions_and_forces = _read_table_pattern(
            filename=input,
            header_pattern=r"\sPOSITION\s+TOTAL-FORCE \(eV/Angst\)\n\s-+",
            row_pattern=r"\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)\s+([+-]?\d+\.\d+)",
            footer_pattern=r"\s--+",
            postprocess=float)
        if positions_and_forces and len(positions_and_forces[-1]) != sum(num_sp):
            # invalid number of positions and forces of the last structure, remove it
            positions_and_forces = positions_and_forces[:-1]
        positions = [[row[:3] for row in step] for step in positions_and_forces]
        forces = [[row[3:] for row in step] for step in positions_and_forces]
    else:
        # vasprun.xml positions are fractional; the first tables are printed
        # during initialization, not after ionic steps
        positions = _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<varray name=\"positions\"\s*>",
            row_pattern=vec_row,
            footer_pattern=r"\s*<\/varray>",
            postprocess=float)[2:]
        forces = _read_table_pattern(
            filename=input,
            header_pattern=r"\s*<varray name=\"forces\"\s*>",
            row_pattern=vec_row,
            footer_pattern=r"\s*<\/varray>",
            postprocess=float)
    if verbose:
        print('found positions: ', len(positions))
        print('found forces: ', len(forces))

    # Keep only steps for which every quantity is available. Forces are included
    # so that a short force list cannot cause an IndexError while writing.
    take_only = min(len(lattices), len(positions), len(forces), len(energies), len(stresses))
    lattices = lattices[:take_only]
    positions = positions[:take_only]
    forces = forces[:take_only]
    energies = energies[:take_only]
    stresses = stresses[:take_only]

    _SEP_star  = "**************************************************\n"
    _SEP_minus = "--------------------------------------------------\n"
    _SEP_equal = "==================================================\n"

    def _fmt_vec(values):
        # One row of 18.16f-formatted numbers separated by two spaces.
        return '  ' + '  '.join(f'{x:18.16f}' for x in values) + '\n'

    def _wrap3(items, newline):
        # Concatenate `items`, inserting `newline` after every third item except the last.
        parts = []
        for k, item in enumerate(items):
            parts.append(item)
            if k % 3 == 2 and k < len(items) - 1:
                parts.append(newline)
        return ''.join(parts)

    # --- write the ML_AB file (closed even if writing fails) ---
    with open(output, 'w') as out:
        out.write(' 1.0 Version\n' + _SEP_star)
        out.write('     The number of configurations\n' + _SEP_minus + f'{len(lattices):11d}\n' + _SEP_star)
        out.write('     The maximum number of atom type\n' + _SEP_minus + f'{len(num_sp):8d}\n' + _SEP_star)
        out.write('     The atom types in the data file\n' + _SEP_minus + '     ')
        out.write(_wrap3([el.ljust(3) for el in sp], '\n     '))
        out.write('\n' + _SEP_star)
        out.write('     The maximum number of atoms per system\n' + _SEP_minus + f'{sum(num_sp):8d}\n' + _SEP_star)
        out.write('     The maximum number of atoms per atom type\n' + _SEP_minus + f'{max(num_sp):8d}\n' + _SEP_star)
        out.write('     Reference atomic energy (eV)\n' + _SEP_minus)
        out.write(_wrap3(['   ' + f'{0:18.16f}' + '     ' for _ in sp], '\n'))  # dummy energies
        out.write('\n' + _SEP_star)
        out.write('     Atomic mass\n' + _SEP_minus)
        out.write(_wrap3(['   ' + f'{m:18.16f}' + '     ' for m in masses], '\n'))
        out.write('\n' + _SEP_star)
        out.write('     The numbers of basis sets per atom type\n' + _SEP_minus + '    ')
        out.write(_wrap3(['     1' for _ in sp], '\n'))
        out.write('\n')
        for el in sp:
            out.write(_SEP_star + f'     Basis set for {el}\n' + _SEP_minus + '          1      1\n')

        for i, (lattice, step_positions, step_forces, energy, stress) in enumerate(
                zip(lattices, positions, forces, energies, stresses), start=1):
            out.write(_SEP_star)
            out.write(f'     Configuration num.{i:7d}\n' + _SEP_equal)
            out.write('     System name\n' + _SEP_minus + '     ' + system_name + '\n' + _SEP_equal)
            out.write('     The number of atom types\n' + _SEP_minus + f'{len(num_sp):8d}\n' + _SEP_equal)
            out.write('     The number of atoms\n' + _SEP_minus + f'{sum(num_sp):8d}\n' + _SEP_star)
            out.write('     Atom types and atom numbers\n' + _SEP_minus)
            for el, n in zip(sp, num_sp):
                out.write('     ' + el.ljust(3) + f'{n:6d}\n')
            out.write(_SEP_equal)
            out.write('     Primitive lattice vectors (ang.)\n' + _SEP_minus)
            for vec in lattice:
                out.write(_fmt_vec(vec))
            out.write(_SEP_equal)
            out.write('     Atomic positions (ang.)\n' + _SEP_minus)
            if input_type == 'OUTCAR':
                for pos in step_positions:
                    out.write(_fmt_vec(pos))
            else:
                # need to convert fractional to Cartesian coordinates for vasprun.xml inputs
                cell = np.array(lattice)
                for pos in step_positions:
                    out.write(_fmt_vec(np.matmul(pos, cell)))
            out.write(_SEP_equal)
            out.write('     Total energy (eV)\n' + _SEP_minus + f'  {energy}\n' + _SEP_equal)
            out.write('     Forces (eV ang.^-1)\n' + _SEP_minus)
            for frc in step_forces:
                out.write(_fmt_vec(frc))
            out.write(_SEP_equal)
            out.write('     Stress (kbar)\n' + _SEP_minus + '     XX YY ZZ\n' + _SEP_minus)
            out.write(_fmt_vec(stress[:3]) + _SEP_minus)
            out.write('     XY YZ ZX\n' + _SEP_minus)
            out.write(_fmt_vec(stress[3:]))

Generate ML_AB formatted output from VASP calculation outputs.

Parameters
  • input (str, optional): Path to the input file, by default 'OUTCAR'
  • input_type (str, optional): Type of the input file, either 'OUTCAR' or 'vasprun', by default 'OUTCAR'
  • output (str, optional): Path to the output file, by default 'ML_AB'
  • overwrite (bool, optional): Flag to overwrite the existing output file, by default False
  • verbose (bool, optional): Flag to print detailed output, by default True
  • system_name (str, optional): Name of the system, by default 'parsed AIMD'
Raises
  • Exception: If input_type is not 'OUTCAR' or 'vasprun'
  • FileExistsError: If the input file does not exist