Skip to content

API Reference

Data Module

cgcnn2.data

AtomCustomJSONInitializer

Bases: AtomInitializer

Initialize atom feature vectors using a JSON file, which is a python dictionary mapping from element number to a list representing the feature vector of the element.

Parameters:

Name Type Description Default
elem_embedding_file str

The path to the .json file

required
Source code in cgcnn2/data.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
class AtomCustomJSONInitializer(AtomInitializer):
    """
    Initialize atom feature vectors from a JSON file that maps each element
    number to the list of floats used as that element's feature vector.

    Args:
        elem_embedding_file (str): The path to the .json file
    """

    def __init__(self, elem_embedding_file):
        # Load the raw mapping and normalize its keys to integers
        # (JSON object keys are always strings).
        with open(elem_embedding_file) as fp:
            raw = json.load(fp)
        embedding = {int(number): vector for number, vector in raw.items()}
        super().__init__(set(embedding))
        # Store each feature vector as a float numpy array.
        for number, vector in embedding.items():
            self._embedding[number] = np.array(vector, dtype=float)

AtomInitializer

Base class for initializing the vector representation for atoms. Use one AtomInitializer per dataset.

Source code in cgcnn2/data.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
class AtomInitializer:
    """
    Base class for initializing the vector representation for atoms.
    Use one `AtomInitializer` per dataset.
    """

    def __init__(self, atom_types):
        """
        Record the known atom types and start with an empty embedding table.

        Args:
            atom_types (set): A set of unique atom types in the dataset.
        """
        self.atom_types = set(atom_types)
        self._embedding = {}

    def get_atom_fea(self, atom_type):
        """
        Return the vector representation stored for an atom type.

        Args:
            atom_type (str): The type of atom to get the vector representation for.
        """
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        """
        Adopt `state_dict` as the embedding table and rebuild the derived
        views: the set of known atom types and the reverse decode map.

        Args:
            state_dict (dict): The state dictionary to load.
        """
        self._embedding = state_dict
        self.atom_types = set(self._embedding)
        self._decodedict = {value: key for key, value in self._embedding.items()}

    def state_dict(self) -> dict:
        """
        Get the state dictionary for the atom initializer.

        Returns:
            dict: The state dictionary.
        """
        return self._embedding

    def decode(self, idx: int) -> str:
        """
        Decode an index to an atom type.

        Args:
            idx (int): The index to decode.

        Returns:
            str: The decoded atom type.
        """
        # Build the reverse lookup lazily on first use.
        try:
            mapping = self._decodedict
        except AttributeError:
            mapping = {value: key for key, value in self._embedding.items()}
            self._decodedict = mapping
        return mapping[idx]

__init__(atom_types)

Initialize the atom types and embedding dictionary.

Parameters:

Name Type Description Default
atom_types set

A set of unique atom types in the dataset.

required
Source code in cgcnn2/data.py
112
113
114
115
116
117
118
119
120
def __init__(self, atom_types):
    """
    Initialize the atom types and embedding dictionary.

    Args:
        atom_types (set): A set of unique atom types in the dataset.
    """
    # Copy into a fresh set so membership checks are O(1) and the caller's
    # collection is not aliased.
    self.atom_types = set(atom_types)
    # Mapping from atom type to feature vector; populated later (e.g. by a
    # subclass __init__ or load_state_dict).
    self._embedding = {}

decode(idx)

Decode an index to an atom type.

Parameters:

Name Type Description Default
idx int

The index to decode.

required

Returns:

Name Type Description
str str

The decoded atom type.

Source code in cgcnn2/data.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
def decode(self, idx: int) -> str:
    """
    Decode an index to an atom type.

    Args:
        idx (int): The index to decode.

    Returns:
        str: The decoded atom type.
    """
    # Build the reverse (value -> key) lookup lazily on first use.
    try:
        mapping = self._decodedict
    except AttributeError:
        mapping = {value: key for key, value in self._embedding.items()}
        self._decodedict = mapping
    return mapping[idx]

get_atom_fea(atom_type)

Get the vector representation for an atom type.

Parameters:

Name Type Description Default
atom_type str

The type of atom to get the vector representation for.

required
Source code in cgcnn2/data.py
122
123
124
125
126
127
128
129
130
def get_atom_fea(self, atom_type):
    """
    Get the vector representation for an atom type.

    Args:
        atom_type (str): The type of atom to get the vector representation for.

    Raises:
        AssertionError: If `atom_type` is not among the known atom types.
    """
    # Guard against unknown atom types before touching the embedding table.
    assert atom_type in self.atom_types
    return self._embedding[atom_type]

load_state_dict(state_dict)

Load the state dictionary for the atom initializer.

Parameters:

Name Type Description Default
state_dict dict

The state dictionary to load.

required
Source code in cgcnn2/data.py
132
133
134
135
136
137
138
139
140
141
142
143
def load_state_dict(self, state_dict):
    """
    Load the state dictionary for the atom initializer.

    Args:
        state_dict (dict): The state dictionary to load.
    """
    # Adopt the given table and rebuild both derived views: the set of
    # known atom types and the reverse (value -> key) decode map.
    self._embedding = state_dict
    self.atom_types = set(self._embedding)
    self._decodedict = {value: key for key, value in self._embedding.items()}

state_dict()

Get the state dictionary for the atom initializer.

Returns:

Name Type Description
dict dict

The state dictionary.

Source code in cgcnn2/data.py
145
146
147
148
149
150
151
152
def state_dict(self) -> dict:
    """
    Get the state dictionary for the atom initializer.

    Returns:
        dict: The state dictionary. This is the internal embedding mapping
        itself (not a copy), so mutating it mutates this initializer.
    """
    return self._embedding

CIFData

Bases: Dataset

The CIFData dataset is a wrapper for a dataset where the crystal structures are stored in the form of CIF files.

id_prop.csv: a CSV file with two columns. The first column records a unique ID for each crystal, and the second column records the value of the target property.

atom_init.json: a JSON file that stores the initialization vector for each element.

ID.cif: a CIF file that records the crystal structure, where ID is the unique ID for the crystal.

Parameters:

Name Type Description Default
root_dir str

The path to the root directory of the dataset

required
max_num_nbr int

The maximum number of neighbors while constructing the crystal graph

12
radius float

The cutoff radius for searching neighbors

8
dmin float

The minimum distance for constructing GaussianDistance

0
step float

The step size for constructing GaussianDistance

0.2
cache_size int | None

The size of the lru cache for the dataset. Default is None.

None
random_seed int

Random seed for shuffling the dataset

123

Returns:

Name Type Description
atom_fea Tensor

shape (n_i, atom_fea_len)

nbr_fea Tensor

shape (n_i, M, nbr_fea_len)

nbr_fea_idx LongTensor

shape (n_i, M)

target Tensor

shape (1, )

cif_id str or int

Unique ID for the crystal

Source code in cgcnn2/data.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
class CIFData(Dataset):
    """
    The CIFData dataset is a wrapper for a dataset where the crystal structures
    are stored in the form of CIF files.

    id_prop.csv: a CSV file with two columns. The first column records a
    unique ID for each crystal, and the second column records the value of
    the target property.

    atom_init.json: a JSON file that stores the initialization vector for each
    element.

    ID.cif: a CIF file that records the crystal structure, where ID is the
    unique ID for the crystal.

    Args:
        root_dir (str): The path to the root directory of the dataset
        max_num_nbr (int): The maximum number of neighbors while constructing the crystal graph
        radius (float): The cutoff radius for searching neighbors
        dmin (float): The minimum distance for constructing GaussianDistance
        step (float): The step size for constructing GaussianDistance
        cache_size (int | None): The size of the lru cache for the dataset. Default is None.
        random_seed (int): Random seed for shuffling the dataset

    Returns:
        atom_fea (torch.Tensor): shape (n_i, atom_fea_len)
        nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len)
        nbr_fea_idx (torch.LongTensor): shape (n_i, M)
        target (torch.Tensor): shape (1, )
        cif_id (str or int): Unique ID for the crystal
    """

    def __init__(
        self,
        root_dir,
        max_num_nbr=12,
        radius=8,
        dmin=0,
        step=0.2,
        cache_size=None,
        random_seed=123,
    ):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        assert os.path.exists(root_dir), "root_dir does not exist!"
        id_prop_file = os.path.join(self.root_dir, "id_prop.csv")
        assert os.path.exists(id_prop_file), "id_prop.csv does not exist!"
        with open(id_prop_file) as f:
            reader = csv.reader(f)
            self.id_prop_data = [row for row in reader]
        # NOTE(review): this seeds the *global* RNG as a side effect; any
        # later shuffling relying on it is process-wide — confirm intended.
        random.seed(random_seed)
        atom_init_file = os.path.join(self.root_dir, "atom_init.json")
        assert os.path.exists(atom_init_file), "atom_init.json does not exist!"
        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        # The fast neighbor-list loader is the one wrapped by the cache;
        # `_load_item` is kept as the original reference implementation.
        self._raw_load_item = self._load_item_fast
        self.cache_size = cache_size
        self._configure_cache()

    def set_cache_size(self, cache_size: Optional[int]) -> None:
        """
        Change the LRU-cache capacity on the fly.

        Args:
            cache_size (int | None): The size of the cache to set, None for unlimited size. Default is None.
        """
        self.cache_size = cache_size
        # Drop entries cached under the old capacity before re-wrapping.
        if hasattr(self._cache_load, "cache_clear"):
            self._cache_load.cache_clear()
        self._configure_cache()

    def clear_cache(self) -> None:
        """
        Clear the current cache.
        """
        # When caching is disabled the loader is the raw function, which has
        # no cache_clear attribute; do nothing in that case.
        if hasattr(self._cache_load, "cache_clear"):
            self._cache_load.cache_clear()

    def __len__(self):
        # One sample per row of id_prop.csv.
        return len(self.id_prop_data)

    def __getitem__(self, idx):
        # All loading goes through the (possibly cached) wrapper.
        return self._cache_load(idx)

    def _configure_cache(self) -> None:
        """
        Wrap `_raw_load_item` with an LRU cache.

        None means unbounded, <= 0 disables caching entirely. Note the cache
        wraps a bound method, so it holds a reference to `self`; that is
        fine here because the cache itself lives on the instance.
        """
        if self.cache_size is None:
            self._cache_load = functools.lru_cache(maxsize=None)(self._raw_load_item)
        elif self.cache_size <= 0:
            self._cache_load = self._raw_load_item
        else:
            self._cache_load = functools.lru_cache(maxsize=self.cache_size)(
                self._raw_load_item
            )

    def _load_item(self, idx):
        """
        Reference loader using pymatgen's `get_all_neighbors`.

        Returns:
            ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id)
        """
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + ".cif"))
        atom_fea = np.vstack(
            [
                self.ari.get_atom_fea(crystal[i].specie.number)
                for i in range(len(crystal))
            ]
        )
        atom_fea = torch.Tensor(atom_fea)
        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        # Sort each site's neighbors by distance (x[1] is the distance).
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < self.max_num_nbr:
                warnings.warn(
                    "{} not find enough neighbors to build graph. "
                    "If it happens frequently, consider increase "
                    "radius.".format(cif_id),
                    stacklevel=2,
                )
                # Pad missing neighbors with index 0 and a distance beyond
                # the cutoff (radius + 1), which expands to ~zero weight.
                nbr_fea_idx.append(
                    list(map(lambda x: x[2], nbr)) + [0] * (self.max_num_nbr - len(nbr))
                )
                nbr_fea.append(
                    list(map(lambda x: x[1], nbr))
                    + [self.radius + 1.0] * (self.max_num_nbr - len(nbr))
                )
            else:
                nbr_fea_idx.append(list(map(lambda x: x[2], nbr[: self.max_num_nbr])))
                nbr_fea.append(list(map(lambda x: x[1], nbr[: self.max_num_nbr])))
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = self.gdf.expand(nbr_fea)
        # BUGFIX: the original re-wrapped `atom_fea` in torch.Tensor a
        # second time here, producing a needless copy of an existing tensor.
        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        # torch.tensor for consistency with _load_item_fast (same dtype/value).
        target = torch.tensor([float(target)])
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id

    def _load_item_fast(self, idx):
        """
        Fast loader using pymatgen's vectorized `get_neighbor_list`.

        Returns:
            ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id)
        """
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + ".cif"))
        atom_fea = np.vstack(
            [
                self.ari.get_atom_fea(crystal[i].specie.number)
                for i in range(len(crystal))
            ]
        )
        atom_fea = torch.Tensor(atom_fea)
        center_idx, neigh_idx, _images, dists = crystal.get_neighbor_list(self.radius)
        n_sites = len(crystal)
        # Group the flat (center, neighbor, distance) triples per site.
        bucket = [[] for _ in range(n_sites)]
        for c, n, d in zip(center_idx, neigh_idx, dists):
            bucket[c].append((n, d))
        bucket = [sorted(lst, key=lambda x: x[1]) for lst in bucket]
        nbr_fea_idx, nbr_fea = [], []
        for lst in bucket:
            if len(lst) < self.max_num_nbr:
                warnings.warn(
                    f"{cif_id} not find enough neighbors to build graph. "
                    "If it happens frequently, consider increase "
                    "radius.",
                    stacklevel=2,
                )
            idxs = [t[0] for t in lst[: self.max_num_nbr]]
            dvec = [t[1] for t in lst[: self.max_num_nbr]]
            # Same padding convention as _load_item: index 0 and a distance
            # beyond the cutoff.
            pad = self.max_num_nbr - len(idxs)
            nbr_fea_idx.append(idxs + [0] * pad)
            nbr_fea.append(dvec + [self.radius + 1.0] * pad)
        nbr_fea_idx = torch.as_tensor(np.array(nbr_fea_idx), dtype=torch.long)
        nbr_fea = self.gdf.expand(np.array(nbr_fea))
        nbr_fea = torch.Tensor(nbr_fea)
        target = torch.tensor([float(target)])
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id

clear_cache()

Clear the current cache.

Source code in cgcnn2/data.py
262
263
264
265
266
267
def clear_cache(self) -> None:
    """
    Clear the current cache.
    """
    # When caching is disabled the loader is the raw (unwrapped) function,
    # which has no cache_clear attribute; do nothing in that case.
    cache_clear = getattr(self._cache_load, "cache_clear", None)
    if cache_clear is not None:
        cache_clear()

set_cache_size(cache_size)

Change the LRU-cache capacity on the fly.

Parameters:

Name Type Description Default
cache_size int | None

The size of the cache to set, None for unlimited size. Default is None.

required
Source code in cgcnn2/data.py
250
251
252
253
254
255
256
257
258
259
260
def set_cache_size(self, cache_size: Optional[int]) -> None:
    """
    Change the LRU-cache capacity on the fly.

    Args:
        cache_size (int | None): The size of the cache to set, None for unlimited size. Default is None.
    """
    self.cache_size = cache_size
    # Drop any entries cached under the old capacity, then rebuild the
    # wrapper with the new capacity.
    cache_clear = getattr(self._cache_load, "cache_clear", None)
    if cache_clear is not None:
        cache_clear()
    self._configure_cache()

CIFData_NoTarget

Bases: Dataset

The CIFData_NoTarget dataset is a wrapper for a dataset where the crystal structures are stored in the form of CIF files.

atom_init.json: a JSON file that stores the initialization vector for each element.

ID.cif: a CIF file that records the crystal structure, where ID is the unique ID for the crystal.

Parameters:

Name Type Description Default
root_dir str

The path to the root directory of the dataset

required
max_num_nbr int

The maximum number of neighbors while constructing the crystal graph

12
radius float

The cutoff radius for searching neighbors

8
dmin float

The minimum distance for constructing GaussianDistance

0
step float

The step size for constructing GaussianDistance

0.2
random_seed int

Random seed for shuffling the dataset

123

Returns:

Name Type Description
atom_fea Tensor

shape (n_i, atom_fea_len)

nbr_fea Tensor

shape (n_i, M, nbr_fea_len)

nbr_fea_idx LongTensor

shape (n_i, M)

target Tensor

shape (1, )

cif_id str or int

Unique ID for the crystal

Source code in cgcnn2/data.py
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
class CIFData_NoTarget(Dataset):
    """
    The CIFData_NoTarget dataset is a wrapper for a dataset where the crystal
    structures are stored in the form of CIF files.

    atom_init.json: a JSON file that stores the initialization vector for each
    element.

    ID.cif: a CIF file that records the crystal structure, where ID is the
    unique ID for the crystal.

    Args:
        root_dir (str): The path to the root directory of the dataset
        max_num_nbr (int): The maximum number of neighbors while constructing the crystal graph
        radius (float): The cutoff radius for searching neighbors
        dmin (float): The minimum distance for constructing GaussianDistance
        step (float): The step size for constructing GaussianDistance
        random_seed (int): Random seed for shuffling the dataset

    Returns:
        atom_fea (torch.Tensor): shape (n_i, atom_fea_len)
        nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len)
        nbr_fea_idx (torch.LongTensor): shape (n_i, M)
        target (torch.Tensor): shape (1, )
        cif_id (str or int): Unique ID for the crystal
    """

    def __init__(
        self, root_dir, max_num_nbr=12, radius=8, dmin=0, step=0.2, random_seed=123
    ):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        assert os.path.exists(root_dir), "root_dir does not exist!"
        # Every *.cif file in the directory is a sample; the target is a
        # placeholder 0 since this dataset carries no labels.
        id_prop_data = [
            (file[:-4], 0) for file in os.listdir(root_dir) if file.endswith(".cif")
        ]
        id_prop_data.sort(key=lambda x: x[0])
        self.id_prop_data = id_prop_data
        random.seed(random_seed)
        atom_init_file = os.path.join(self.root_dir, "atom_init.json")
        assert os.path.exists(atom_init_file), "atom_init.json does not exist!"
        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        # BUGFIX: the original decorated __getitem__ with a class-level
        # @functools.lru_cache, which keys on `self` and keeps every
        # instance (and all loaded structures) alive for the process
        # lifetime (ruff B019). Cache per instance instead, matching
        # CIFData's design.
        self._cached_load = functools.lru_cache(maxsize=None)(self._load_item)

    def __len__(self):
        # One sample per CIF file found in root_dir.
        return len(self.id_prop_data)

    def __getitem__(self, idx):
        return self._cached_load(idx)

    def _load_item(self, idx):
        """
        Load and featurize the structure at `idx`.

        Returns:
            ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id)
        """
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + ".cif"))
        atom_fea = np.vstack(
            [
                self.ari.get_atom_fea(crystal[i].specie.number)
                for i in range(len(crystal))
            ]
        )
        atom_fea = torch.Tensor(atom_fea)
        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        # Sort each site's neighbors by distance (x[1] is the distance).
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < self.max_num_nbr:
                warnings.warn(
                    "{} not find enough neighbors to build graph. "
                    "If it happens frequently, consider increase "
                    "radius.".format(cif_id),
                    stacklevel=2,
                )
                # Pad missing neighbors with index 0 and a distance beyond
                # the cutoff (radius + 1), which expands to ~zero weight.
                nbr_fea_idx.append(
                    list(map(lambda x: x[2], nbr)) + [0] * (self.max_num_nbr - len(nbr))
                )
                nbr_fea.append(
                    list(map(lambda x: x[1], nbr))
                    + [self.radius + 1.0] * (self.max_num_nbr - len(nbr))
                )
            else:
                nbr_fea_idx.append(list(map(lambda x: x[2], nbr[: self.max_num_nbr])))
                nbr_fea.append(list(map(lambda x: x[1], nbr[: self.max_num_nbr])))
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = self.gdf.expand(nbr_fea)
        # BUGFIX: the original re-wrapped `atom_fea` in torch.Tensor a
        # second time here, producing a needless copy of an existing tensor.
        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        target = torch.tensor([float(target)])
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id

GaussianDistance

Expands the distance by Gaussian basis.

Unit: angstrom

Source code in cgcnn2/data.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class GaussianDistance:
    """
    Expands the distance by Gaussian basis.

    Unit: angstrom
    """

    def __init__(self, dmin, dmax, step, var=None):
        """
        Args:
            dmin (float): Minimum interatomic distance (center of the first Gaussian).
            dmax (float): Maximum interatomic distance (center of the last Gaussian).
            step (float): Spacing between consecutive Gaussian centers.
            var (float, optional): Variance of each Gaussian. If None, defaults to step.
        """

        # The centers must span a non-empty range at least one step wide.
        assert dmin < dmax
        assert dmax - dmin > step
        # Centers of the Gaussian basis functions, spaced `step` apart.
        self.filter = np.arange(dmin, dmax + step, step)
        self.var = step if var is None else var

    def expand(self, distances):
        """
        Project each scalar distance onto a set of Gaussian basis functions.

        Args:
            distances (np.ndarray): An array of interatomic distances.

        Returns:
            expanded_distance (np.ndarray): An array where the last dimension contains the Gaussian basis values for each input distance.
        """

        # Broadcast each distance against every center, then apply the
        # Gaussian kernel elementwise; the result gains a trailing dimension.
        offsets = distances[..., np.newaxis] - self.filter
        return np.exp(-(offsets**2) / self.var**2)

__init__(dmin, dmax, step, var=None)

Parameters:

Name Type Description Default
dmin float

Minimum interatomic distance (center of the first Gaussian).

required
dmax float

Maximum interatomic distance (center of the last Gaussian).

required
step float

Spacing between consecutive Gaussian centers.

required
var float

Variance of each Gaussian. If None, defaults to step.

None
Source code in cgcnn2/data.py
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def __init__(self, dmin, dmax, step, var=None):
    """
    Args:
        dmin (float): Minimum interatomic distance (center of the first Gaussian).
        dmax (float): Maximum interatomic distance (center of the last Gaussian).
        step (float): Spacing between consecutive Gaussian centers.
        var (float, optional): Variance of each Gaussian. If None, defaults to step.
    """

    # The centers must span a non-empty range at least one step wide.
    assert dmin < dmax
    assert dmax - dmin > step
    # Gaussian centers from dmin to dmax (inclusive), spaced `step` apart.
    self.filter = np.arange(dmin, dmax + step, step)
    if var is None:
        var = step
    self.var = var

expand(distances)

Project each scalar distance onto a set of Gaussian basis functions.

Parameters:

Name Type Description Default
distances ndarray

An array of interatomic distances.

required

Returns:

Name Type Description
expanded_distance ndarray

An array where the last dimension contains the Gaussian basis values for each input distance.

Source code in cgcnn2/data.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def expand(self, distances):
    """
    Project each scalar distance onto a set of Gaussian basis functions.

    Args:
        distances (np.ndarray): An array of interatomic distances.

    Returns:
        expanded_distance (np.ndarray): An array where the last dimension contains the Gaussian basis values for each input distance.
    """

    # Broadcast each distance against every Gaussian center and apply the
    # kernel elementwise; the result gains one trailing dimension.
    offsets = distances[..., np.newaxis] - self.filter
    return np.exp(-(offsets**2) / self.var**2)

collate_pool(dataset_list)

Collate a list of data and return a batch for predicting crystal properties.

Parameters:

Name Type Description Default
dataset_list list of tuples

List of tuples for each data point. Each tuple contains:

required
atom_fea Tensor

shape (n_i, atom_fea_len) Atom features for each atom in the crystal

required
nbr_fea Tensor

shape (n_i, M, nbr_fea_len) Bond features for each atom's M neighbors

required
nbr_fea_idx LongTensor

shape (n_i, M) Indices of M neighbors of each atom

required
target Tensor

shape (1, ) Target value for prediction

required

Returns:

Name Type Description
batch_atom_fea Tensor

shape (N, orig_atom_fea_len) Atom features from atom type

batch_nbr_fea Tensor

shape (N, M, nbr_fea_len) Bond features of each atom's M neighbors

batch_nbr_fea_idx LongTensor

shape (N, M) Indices of M neighbors of each atom

crystal_atom_idx list of torch.LongTensor

length N0 Mapping from the crystal idx to atom idx

batch_target Tensor

shape (N, 1) Target value for prediction

batch_cif_ids list of str or int

Unique IDs for each crystal

Source code in cgcnn2/data.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
def collate_pool(dataset_list):
    """
    Collate a list of data and return a batch for predicting crystal properties.

    Args:
        dataset_list (list of tuples): List of tuples for each data point. Each tuple contains:
        atom_fea (torch.Tensor): shape (n_i, atom_fea_len) Atom features for each atom in the crystal
        nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len) Bond features for each atom's M neighbors
        nbr_fea_idx (torch.LongTensor): shape (n_i, M) Indices of M neighbors of each atom
        target (torch.Tensor): shape (1, ) Target value for prediction
        cif_id (str or int) Unique ID for the crystal

    Returns:
        batch_atom_fea (torch.Tensor): shape (N, orig_atom_fea_len) Atom features from atom type
        batch_nbr_fea (torch.Tensor): shape (N, M, nbr_fea_len) Bond features of each atom's M neighbors
        batch_nbr_fea_idx (torch.LongTensor): shape (N, M) Indices of M neighbors of each atom
        crystal_atom_idx (list of torch.LongTensor): length N0 Mapping from the crystal idx to atom idx
        batch_target (torch.Tensor): shape (N, 1) Target value for prediction
        batch_cif_ids (list of str or int): Unique IDs for each crystal
    """

    batch_atom_fea = []
    batch_nbr_fea = []
    batch_nbr_fea_idx = []
    crystal_atom_idx = []
    batch_target = []
    batch_cif_ids = []
    base_idx = 0
    for (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id in dataset_list:
        n_atoms = atom_fea.shape[0]  # atoms in this crystal
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        # Shift neighbor indices so they address rows of the stacked batch.
        batch_nbr_fea_idx.append(nbr_fea_idx + base_idx)
        crystal_atom_idx.append(torch.LongTensor(np.arange(n_atoms) + base_idx))
        batch_target.append(target)
        batch_cif_ids.append(cif_id)
        base_idx += n_atoms
    batched_inputs = (
        torch.cat(batch_atom_fea, dim=0),
        torch.cat(batch_nbr_fea, dim=0),
        torch.cat(batch_nbr_fea_idx, dim=0),
        crystal_atom_idx,
    )
    return batched_inputs, torch.stack(batch_target, dim=0), batch_cif_ids

full_set_split(full_set_dir, train_ratio, valid_ratio, train_force_dir=None, random_seed=0)

Split the full set into train, valid, and test sets into a temporary directory.

Parameters:

Name Type Description Default
full_set_dir str

The path to the full set

required
train_ratio float

The ratio of the training set

required
valid_ratio float

The ratio of the validation set

required
train_force_dir str

The path to the forced training set. Adding this will no longer keep the original split ratio.

None
random_seed int

The random seed for the split

0

Returns:

Name Type Description
train_dir str

The path to a temporary directory containing the train set

valid_dir str

The path to a temporary directory containing the valid set

test_dir str

The path to a temporary directory containing the test set

Source code in cgcnn2/data.py
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
def full_set_split(
    full_set_dir: str,
    train_ratio: float,
    valid_ratio: float,
    train_force_dir: str | None = None,
    random_seed: int = 0,
):
    """
    Split the full set into train, valid, and test sets into a temporary directory.

    Args:
        full_set_dir (str): The path to the full set
        train_ratio (float): The ratio of the training set
        valid_ratio (float): The ratio of the validation set
        train_force_dir (str): The path to the forced training set. Adding this will no longer keep the original split ratio.
        random_seed (int): The random seed for the split

    Returns:
        train_dir (str): The path to a temporary directory containing the train set
        valid_dir (str): The path to a temporary directory containing the valid set
        test_dir (str): The path to a temporary directory containing the test set
    """
    df = pd.read_csv(
        os.path.join(full_set_dir, "id_prop.csv"),
        header=None,
        names=["cif_id", "property"],
    )

    rng = np.random.RandomState(random_seed)
    df_shuffle = df.sample(frac=1.0, random_state=rng).reset_index(drop=True)

    n_total = len(df_shuffle)
    n_train = int(n_total * train_ratio)
    n_valid = int(n_total * valid_ratio)

    train_df = df_shuffle[:n_train]
    valid_df = df_shuffle[n_train : n_train + n_valid]
    test_df = df_shuffle[n_train + n_valid :]

    temp_train_dir = tempfile.mkdtemp()
    temp_valid_dir = tempfile.mkdtemp()
    temp_test_dir = tempfile.mkdtemp()

    atexit.register(shutil.rmtree, temp_train_dir, ignore_errors=True)
    atexit.register(shutil.rmtree, temp_valid_dir, ignore_errors=True)
    atexit.register(shutil.rmtree, temp_test_dir, ignore_errors=True)

    splits = {
        temp_train_dir: train_df,
        temp_valid_dir: valid_df,
        temp_test_dir: test_df,
    }

    for temp_dir, df in splits.items():
        for cif_id in df["cif_id"]:
            src = os.path.join(full_set_dir, f"{cif_id}.cif")
            dst = os.path.join(temp_dir, f"{cif_id}.cif")
            shutil.copy(src, dst)

    train_df.to_csv(
        os.path.join(temp_train_dir, "id_prop.csv"), index=False, header=False
    )
    valid_df.to_csv(
        os.path.join(temp_valid_dir, "id_prop.csv"), index=False, header=False
    )
    test_df.to_csv(
        os.path.join(temp_test_dir, "id_prop.csv"), index=False, header=False
    )

    shutil.copy(os.path.join(full_set_dir, "atom_init.json"), temp_train_dir)
    shutil.copy(os.path.join(full_set_dir, "atom_init.json"), temp_valid_dir)
    shutil.copy(os.path.join(full_set_dir, "atom_init.json"), temp_test_dir)

    if train_force_dir is not None:
        df_force = pd.read_csv(
            os.path.join(train_force_dir, "id_prop.csv"),
            header=None,
            names=["cif_id", "property"],
        )

        train_df = pd.concat([train_df, df_force])
        train_df.to_csv(
            os.path.join(temp_train_dir, "id_prop.csv"), index=False, header=False
        )

        for cif_id in df_force["cif_id"]:
            src = os.path.join(train_force_dir, f"{cif_id}.cif")
            dst = os.path.join(temp_train_dir, f"{cif_id}.cif")
            shutil.copy(src, dst)

    return temp_train_dir, temp_valid_dir, temp_test_dir

Model Framework

cgcnn2.model

ConvLayer

Bases: Module

Convolutional layer for graph data.

Performs a convolutional operation on graphs, updating atom features based on their neighbors.

Source code in cgcnn2/model.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class ConvLayer(nn.Module):
    """
    Graph convolution layer.

    Updates each atom's hidden features from the features of its neighbors
    and of the connecting bonds, using a gated (sigmoid x softplus) message
    that is batch-normalized and added back residually.
    """

    def __init__(self, atom_fea_len: int, nbr_fea_len: int) -> None:
        """
        Initialize the ConvLayer.

        Args:
            atom_fea_len (int): Number of atom hidden features.
            nbr_fea_len (int): Number of bond (neighbor) features.
        """
        super().__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        # One joint linear map over [center atom | neighbor atom | bond]
        # produces both the gate and the core channels.
        pair_dim = 2 * atom_fea_len + nbr_fea_len
        gated_dim = 2 * atom_fea_len
        self.fc_full = nn.Linear(pair_dim, gated_dim)
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(gated_dim)
        self.bn2 = nn.BatchNorm1d(atom_fea_len)
        self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """
        Forward pass.

        N: Total number of atoms in the batch
        M: Max number of neighbors

        Args:
            atom_in_fea (torch.Tensor): shape (N, atom_fea_len)
                Atom hidden features before convolution.
            nbr_fea (torch.Tensor): shape (N, M, nbr_fea_len)
                Bond features of each atom's M neighbors.
            nbr_fea_idx (torch.LongTensor): shape (N, M)
                Indices of the M neighbors of each atom.

        Returns:
            torch.Tensor: shape (N, atom_fea_len)
                Atom hidden features after convolution.
        """
        n_atoms, n_nbrs = nbr_fea_idx.shape
        # Gather every neighbor's hidden features: (N, M, atom_fea_len).
        gathered = atom_in_fea[nbr_fea_idx, :]
        # Concatenate [center atom | neighbor atom | bond] features.
        center = atom_in_fea.unsqueeze(1).expand(n_atoms, n_nbrs, self.atom_fea_len)
        pair_fea = torch.cat([center, gathered, nbr_fea], dim=2)
        gated = self.fc_full(pair_fea)
        # BatchNorm1d needs a 2-D input, so flatten over the neighbor axis.
        gated = self.bn1(gated.view(-1, 2 * self.atom_fea_len))
        gated = gated.view(n_atoms, n_nbrs, 2 * self.atom_fea_len)
        gate, core = gated.chunk(2, dim=2)
        # Sum the gated messages over the neighbor axis.
        message = torch.sum(self.sigmoid(gate) * self.softplus1(core), dim=1)
        # Residual update of the atom features.
        return self.softplus2(atom_in_fea + self.bn2(message))

__init__(atom_fea_len, nbr_fea_len)

Initialize the ConvLayer.

Parameters:

Name Type Description Default
atom_fea_len int

Number of atom hidden features.

required
nbr_fea_len int

Number of bond (neighbor) features.

required
Source code in cgcnn2/model.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def __init__(self, atom_fea_len: int, nbr_fea_len: int) -> None:
    """
    Initialize the ConvLayer.

    Args:
        atom_fea_len (int): Number of atom hidden features.
        nbr_fea_len (int): Number of bond (neighbor) features.
    """
    super(ConvLayer, self).__init__()
    self.atom_fea_len = atom_fea_len
    self.nbr_fea_len = nbr_fea_len
    # One joint linear map over [center atom | neighbor atom | bond]
    # produces both the gate and the core channels.
    pair_dim = 2 * atom_fea_len + nbr_fea_len
    gated_dim = 2 * atom_fea_len
    self.fc_full = nn.Linear(pair_dim, gated_dim)
    self.sigmoid = nn.Sigmoid()
    self.softplus1 = nn.Softplus()
    self.bn1 = nn.BatchNorm1d(gated_dim)
    self.bn2 = nn.BatchNorm1d(atom_fea_len)
    self.softplus2 = nn.Softplus()

forward(atom_in_fea, nbr_fea, nbr_fea_idx)

Forward pass

N: Total number of atoms in the batch M: Max number of neighbors

Parameters:

Name Type Description Default
atom_in_fea Tensor

Variable(torch.Tensor) shape (N, atom_fea_len) Atom hidden features before convolution

required
nbr_fea Tensor

Variable(torch.Tensor) shape (N, M, nbr_fea_len) Bond features of each atom's M neighbors

required
nbr_fea_idx LongTensor

shape (N, M) Indices of M neighbors of each atom

required

Returns:

Name Type Description
atom_out_fea Variable

shape (N, atom_fea_len) Atom hidden features after convolution

Source code in cgcnn2/model.py
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
    """
    Forward pass.

    N: Total number of atoms in the batch
    M: Max number of neighbors

    Args:
        atom_in_fea (torch.Tensor): shape (N, atom_fea_len)
            Atom hidden features before convolution.
        nbr_fea (torch.Tensor): shape (N, M, nbr_fea_len)
            Bond features of each atom's M neighbors.
        nbr_fea_idx (torch.LongTensor): shape (N, M)
            Indices of the M neighbors of each atom.

    Returns:
        torch.Tensor: shape (N, atom_fea_len)
            Atom hidden features after convolution.
    """
    n_atoms, n_nbrs = nbr_fea_idx.shape
    # Gather every neighbor's hidden features: (N, M, atom_fea_len).
    gathered = atom_in_fea[nbr_fea_idx, :]
    # Concatenate [center atom | neighbor atom | bond] features.
    center = atom_in_fea.unsqueeze(1).expand(n_atoms, n_nbrs, self.atom_fea_len)
    pair_fea = torch.cat([center, gathered, nbr_fea], dim=2)
    gated = self.fc_full(pair_fea)
    # BatchNorm1d needs a 2-D input, so flatten over the neighbor axis.
    gated = self.bn1(gated.view(-1, 2 * self.atom_fea_len))
    gated = gated.view(n_atoms, n_nbrs, 2 * self.atom_fea_len)
    gate, core = gated.chunk(2, dim=2)
    # Sum the gated messages over the neighbor axis.
    message = torch.sum(self.sigmoid(gate) * self.softplus1(core), dim=1)
    # Residual update of the atom features.
    return self.softplus2(atom_in_fea + self.bn2(message))

CrystalGraphConvNet

Bases: Module

Create a crystal graph convolutional neural network for predicting total material properties.

Source code in cgcnn2/model.py
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
class CrystalGraphConvNet(nn.Module):
    """
    Crystal graph convolutional neural network for predicting total
    material properties from a crystal graph.
    """

    def __init__(
        self,
        orig_atom_fea_len: int,
        nbr_fea_len: int,
        atom_fea_len: int = 64,
        n_conv: int = 3,
        h_fea_len: int = 128,
        n_h: int = 1,
        classification: bool = False,
    ) -> None:
        """
        Initialize CrystalGraphConvNet.

        Args:
            orig_atom_fea_len (int): Number of atom features in the input.
            nbr_fea_len (int): Number of bond features.
            atom_fea_len (int): Number of hidden atom features in the convolutional layers
            n_conv (int): Number of convolutional layers
            h_fea_len (int): Number of hidden features after pooling
            n_h (int): Number of hidden layers after pooling
            classification (bool): Whether to use classification or regression
        """
        super().__init__()
        self.classification = classification
        # Project raw atom features into the hidden space.
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        conv_stack = [
            ConvLayer(atom_fea_len=atom_fea_len, nbr_fea_len=nbr_fea_len)
            for _ in range(n_conv)
        ]
        self.convs = nn.ModuleList(conv_stack)
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()
        # Optional extra fully-connected stack after pooling; only created
        # when requested, so downstream code probes for it with hasattr.
        if n_h > 1:
            hidden = [nn.Linear(h_fea_len, h_fea_len) for _ in range(n_h - 1)]
            self.fcs = nn.ModuleList(hidden)
            self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_h - 1)])

        # Two logits for classification, one scalar for regression.
        out_dim = 2 if self.classification else 1
        self.fc_out = nn.Linear(h_fea_len, out_dim)

        if self.classification:
            self.logsoftmax = nn.LogSoftmax(dim=1)
            self.dropout = nn.Dropout()

    def forward(
        self,
        atom_fea: torch.Tensor,
        nbr_fea: torch.Tensor,
        nbr_fea_idx: torch.LongTensor,
        crystal_atom_idx: list[torch.LongTensor],
    ):
        """
        Forward pass.

        N: Total number of atoms in the batch
        M: Max number of neighbors
        N0: Total number of crystals in the batch

        Args:
            atom_fea (torch.Tensor): shape (N, orig_atom_fea_len)
                Atom features from atom type.
            nbr_fea (torch.Tensor): shape (N, M, nbr_fea_len)
                Bond features of each atom's M neighbors.
            nbr_fea_idx (torch.LongTensor): shape (N, M)
                Indices of the M neighbors of each atom.
            crystal_atom_idx (list of torch.LongTensor): Mapping from the
                crystal idx to atom idx.

        Returns:
            tuple: ``(out, crys_fea)`` where ``out`` has shape (N0, 1) for
            regression (log-probabilities of shape (N0, 2) for
            classification) and ``crys_fea`` is the pooled crystal feature
            vector fed to the output layer.
        """
        atom_fea = self.embedding(atom_fea)
        for conv in self.convs:
            atom_fea = conv(atom_fea, nbr_fea, nbr_fea_idx)
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        crys_fea = self.conv_to_fc_softplus(crys_fea)
        if self.classification:
            crys_fea = self.dropout(crys_fea)
        if hasattr(self, "fcs") and hasattr(self, "softpluses"):
            for fc, act in zip(self.fcs, self.softpluses):
                crys_fea = act(fc(crys_fea))
        out = self.fc_out(crys_fea)
        if self.classification:
            out = self.logsoftmax(out)
        return out, crys_fea

    def pooling(
        self, atom_fea: torch.Tensor, crystal_atom_idx: list[torch.LongTensor]
    ) -> torch.Tensor:
        """
        Mean-pool atom features into one feature vector per crystal.

        Args:
            atom_fea (torch.Tensor): shape (N, atom_fea_len)
                Atom embeddings for all atoms in the batch.
            crystal_atom_idx (list[torch.LongTensor]):
                crystal_atom_idx[i] holds the indices of the atoms that
                belong to crystal i; together the index tensors must cover
                every atom exactly once.

        Returns:
            torch.Tensor: shape (n_crystals, atom_fea_len)
                Mean-pooled crystal embeddings.
        """
        covered = sum(len(idx) for idx in crystal_atom_idx)
        assert covered == atom_fea.shape[0]
        pooled = [atom_fea[idx].mean(dim=0) for idx in crystal_atom_idx]
        return torch.stack(pooled, dim=0)

__init__(orig_atom_fea_len, nbr_fea_len, atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1, classification=False)

Initialize CrystalGraphConvNet.

Parameters:

Name Type Description Default
orig_atom_fea_len int

Number of atom features in the input.

required
nbr_fea_len int

Number of bond features.

required
atom_fea_len int

Number of hidden atom features in the convolutional layers

64
n_conv int

Number of convolutional layers

3
h_fea_len int

Number of hidden features after pooling

128
n_h int

Number of hidden layers after pooling

1
classification bool

Whether to use classification or regression

False
Source code in cgcnn2/model.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def __init__(
    self,
    orig_atom_fea_len: int,
    nbr_fea_len: int,
    atom_fea_len: int = 64,
    n_conv: int = 3,
    h_fea_len: int = 128,
    n_h: int = 1,
    classification: bool = False,
) -> None:
    """
    Initialize CrystalGraphConvNet.

    Args:
        orig_atom_fea_len (int): Number of atom features in the input.
        nbr_fea_len (int): Number of bond features.
        atom_fea_len (int): Number of hidden atom features in the convolutional layers
        n_conv (int): Number of convolutional layers
        h_fea_len (int): Number of hidden features after pooling
        n_h (int): Number of hidden layers after pooling
        classification (bool): Whether to use classification or regression
    """
    super(CrystalGraphConvNet, self).__init__()
    self.classification = classification
    # Project raw atom features into the hidden space.
    self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
    conv_stack = [
        ConvLayer(atom_fea_len=atom_fea_len, nbr_fea_len=nbr_fea_len)
        for _ in range(n_conv)
    ]
    self.convs = nn.ModuleList(conv_stack)
    self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
    self.conv_to_fc_softplus = nn.Softplus()
    # Optional extra fully-connected stack after pooling; only created
    # when requested, so downstream code probes for it with hasattr.
    if n_h > 1:
        hidden = [nn.Linear(h_fea_len, h_fea_len) for _ in range(n_h - 1)]
        self.fcs = nn.ModuleList(hidden)
        self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_h - 1)])

    # Two logits for classification, one scalar for regression.
    out_dim = 2 if self.classification else 1
    self.fc_out = nn.Linear(h_fea_len, out_dim)

    if self.classification:
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout()

forward(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

Forward pass

N: Total number of atoms in the batch M: Max number of neighbors N0: Total number of crystals in the batch

Parameters:

Name Type Description Default
atom_fea Tensor

Variable(torch.Tensor) shape (N, orig_atom_fea_len) Atom features from atom type

required
nbr_fea Tensor

Variable(torch.Tensor) shape (N, M, nbr_fea_len) Bond features of each atom's M neighbors

required
nbr_fea_idx LongTensor

shape (N, M) Indices of M neighbors of each atom

required
crystal_atom_idx list of torch.LongTensor

Mapping from the crystal idx to atom idx

required

Returns:

Name Type Description
prediction Variable

shape (N0, 1) Predicted property value per crystal (log-probabilities of shape (N0, 2) in classification mode)

Source code in cgcnn2/model.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
def forward(
    self,
    atom_fea: torch.Tensor,
    nbr_fea: torch.Tensor,
    nbr_fea_idx: torch.LongTensor,
    crystal_atom_idx: list[torch.LongTensor],
):
    """
    Forward pass.

    N: Total number of atoms in the batch
    M: Max number of neighbors
    N0: Total number of crystals in the batch

    Args:
        atom_fea (torch.Tensor): shape (N, orig_atom_fea_len)
            Atom features from atom type.
        nbr_fea (torch.Tensor): shape (N, M, nbr_fea_len)
            Bond features of each atom's M neighbors.
        nbr_fea_idx (torch.LongTensor): shape (N, M)
            Indices of the M neighbors of each atom.
        crystal_atom_idx (list of torch.LongTensor): Mapping from the
            crystal idx to atom idx.

    Returns:
        tuple: ``(out, crys_fea)`` where ``out`` is the per-crystal
        prediction from the output layer (log-probabilities in
        classification mode) and ``crys_fea`` is the pooled crystal
        feature vector fed to that layer.
    """
    atom_fea = self.embedding(atom_fea)
    for conv in self.convs:
        atom_fea = conv(atom_fea, nbr_fea, nbr_fea_idx)
    crys_fea = self.pooling(atom_fea, crystal_atom_idx)
    crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
    crys_fea = self.conv_to_fc_softplus(crys_fea)
    if self.classification:
        crys_fea = self.dropout(crys_fea)
    # The hidden stack exists only when n_h > 1 (see __init__).
    if hasattr(self, "fcs") and hasattr(self, "softpluses"):
        for fc, act in zip(self.fcs, self.softpluses):
            crys_fea = act(fc(crys_fea))
    out = self.fc_out(crys_fea)
    if self.classification:
        out = self.logsoftmax(out)
    return out, crys_fea

pooling(atom_fea, crystal_atom_idx)

Aggregate atom features into crystal-level features by mean pooling.

Parameters:

Name Type Description Default
atom_fea Tensor

shape (N, atom_fea_len) Atom embeddings for all atoms in the batch.

required
crystal_atom_idx list[LongTensor]

crystal_atom_idx[i] contains the indices of atoms belonging to the i-th crystal. The concatenated indices must cover every atom exactly once.

required

Returns:

Name Type Description
mean_fea Tensor

shape (n_crystals, atom_fea_len) Mean-pooled crystal embeddings.

Source code in cgcnn2/model.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
def pooling(
    self, atom_fea: torch.Tensor, crystal_atom_idx: list[torch.LongTensor]
) -> torch.Tensor:
    """
    Mean-pool atom features into one feature vector per crystal.

    Args:
        atom_fea (torch.Tensor): shape (N, atom_fea_len)
            Atom embeddings for all atoms in the batch.
        crystal_atom_idx (list[torch.LongTensor]):
            crystal_atom_idx[i] holds the indices of the atoms that belong
            to crystal i; together the index tensors must cover every atom
            exactly once.

    Returns:
        torch.Tensor: shape (n_crystals, atom_fea_len)
            Mean-pooled crystal embeddings.
    """
    covered = sum(len(idx) for idx in crystal_atom_idx)
    assert covered == atom_fea.shape[0]
    pooled = [atom_fea[idx].mean(dim=0) for idx in crystal_atom_idx]
    return torch.stack(pooled, dim=0)

Utility Function

cgcnn2.util

Normalizer

Normalizes a PyTorch tensor and allows restoring it later.

This class keeps track of the mean and standard deviation of a tensor and provides methods to normalize and denormalize tensors using these statistics.

Source code in cgcnn2/util.py
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
class Normalizer:
    """
    Tracks the mean and standard deviation of a sample tensor and uses them
    to normalize tensors and to restore normalized tensors later.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        """
        Compute and store the mean and standard deviation of *tensor*.

        Args:
            tensor (torch.Tensor): Sample tensor used to derive the statistics.
        """
        self.mean: torch.Tensor = tensor.mean()
        self.std: torch.Tensor = tensor.std()

    def norm(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize *tensor* with the stored statistics.

        Args:
            tensor (torch.Tensor): Tensor to normalize.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        centered = tensor - self.mean
        return centered / self.std

    def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor:
        """
        Invert :meth:`norm` on *normed_tensor*.

        Args:
            normed_tensor (torch.Tensor): Normalized tensor to denormalize.

        Returns:
            torch.Tensor: Denormalized tensor.
        """
        return self.mean + normed_tensor * self.std

    def state_dict(self) -> dict[str, torch.Tensor]:
        """
        Return the stored statistics as a serializable dictionary.

        Returns:
            dict[str, torch.Tensor]: Dictionary with 'mean' and 'std' keys.
        """
        return dict(mean=self.mean, std=self.std)

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        """
        Restore the statistics from a :meth:`state_dict` dictionary.

        Args:
            state_dict (dict[str, torch.Tensor]): Dictionary containing 'mean' and 'std'.
        """
        self.mean, self.std = state_dict["mean"], state_dict["std"]

__init__(tensor)

Initialize the Normalizer with a sample tensor to calculate mean and standard deviation.

Parameters:

Name Type Description Default
tensor Tensor

Sample tensor to compute mean and standard deviation.

required
Source code in cgcnn2/util.py
433
434
435
436
437
438
439
440
441
def __init__(self, tensor: torch.Tensor) -> None:
    """
    Compute and store the mean and standard deviation of *tensor*.

    Args:
        tensor (torch.Tensor): Sample tensor used to derive the statistics.
    """
    self.mean: torch.Tensor = tensor.mean()
    self.std: torch.Tensor = tensor.std()

denorm(normed_tensor)

Denormalize a tensor using the stored mean and standard deviation.

Parameters:

Name Type Description Default
normed_tensor Tensor

Normalized tensor to denormalize.

required

Returns:

Type Description
Tensor

torch.Tensor: Denormalized tensor.

Source code in cgcnn2/util.py
455
456
457
458
459
460
461
462
463
464
465
def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor:
    """
    Invert the normalization of *normed_tensor* using the stored statistics.

    Args:
        normed_tensor (torch.Tensor): Normalized tensor to denormalize.

    Returns:
        torch.Tensor: Denormalized tensor.
    """
    return self.mean + normed_tensor * self.std

load_state_dict(state_dict)

Loads the mean and standard deviation from a state dictionary.

Parameters:

Name Type Description Default
state_dict dict[str, Tensor]

State dictionary containing 'mean' and 'std'.

required
Source code in cgcnn2/util.py
476
477
478
479
480
481
482
483
484
def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
    """
    Restore the mean and standard deviation from a state dictionary.

    Args:
        state_dict (dict[str, torch.Tensor]): Dictionary containing 'mean' and 'std'.
    """
    self.mean, self.std = state_dict["mean"], state_dict["std"]

norm(tensor)

Normalize a tensor using the stored mean and standard deviation.

Parameters:

Name Type Description Default
tensor Tensor

Tensor to normalize.

required

Returns:

Type Description
Tensor

torch.Tensor: Normalized tensor.

Source code in cgcnn2/util.py
443
444
445
446
447
448
449
450
451
452
453
def norm(self, tensor: torch.Tensor) -> torch.Tensor:
    """
    Normalize *tensor* using the stored mean and standard deviation.

    Args:
        tensor (torch.Tensor): Tensor to normalize.

    Returns:
        torch.Tensor: Normalized tensor.
    """
    centered = tensor - self.mean
    return centered / self.std

state_dict()

Returns the state dictionary containing the mean and standard deviation.

Returns:

Type Description
dict[str, Tensor]

dict[str, torch.Tensor]: State dictionary.

Source code in cgcnn2/util.py
467
468
469
470
471
472
473
474
def state_dict(self) -> dict[str, torch.Tensor]:
    """
    Return the stored statistics as a serializable dictionary.

    Returns:
        dict[str, torch.Tensor]: Dictionary with 'mean' and 'std' keys.
    """
    return dict(mean=self.mean, std=self.std)

cgcnn_descriptor(model, loader, device, verbose)

This function takes a pre-trained CGCNN model and a dataset, runs inference to generate predictions and features from the last layer, and returns the predictions and features. It is not necessary to have target values for the predicted set.

Parameters:

Name Type Description Default
model Module

The trained CGCNN model.

required
loader DataLoader

DataLoader for the dataset.

required
device str

The device ('cuda' or 'cpu') where the model will be run.

required
verbose int

The verbosity level of the output.

required

Returns:

Name Type Description
tuple tuple[list[float], list[Tensor]]

A tuple containing: - list: Model predictions - list: Crystal features from the last layer

Notes

This function is intended for use in programmatic downstream analysis, where the user wants to continue downstream analysis using predictions or features (descriptors) generated by the model. For the command-line interface, consider using the cgcnn_pr script instead.

Source code in cgcnn2/util.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
def cgcnn_descriptor(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: str,
    verbose: int,
) -> tuple[list[float], list[torch.Tensor]]:
    """
    This function takes a pre-trained CGCNN model and a dataset, runs inference
    to generate predictions and features from the last layer, and returns the
    predictions and features. It is not necessary to have target values for the
    predicted set.

    Args:
        model (torch.nn.Module): The trained CGCNN model.
        loader (torch.utils.data.DataLoader): DataLoader for the dataset.
        device (str): The device ('cuda' or 'cpu') where the model will be run.
        verbose (int): The verbosity level of the output.

    Returns:
        tuple: A tuple containing:
            - list: Model predictions
            - list: Crystal features from the last layer

    Notes:
        This function is intended for use in programmatic downstream analysis,
        where the user wants to continue downstream analysis using predictions or
        features (descriptors) generated by the model. For the command-line interface,
        consider using the cgcnn_pr script instead.
    """
    model.eval()
    outputs_list: list[float] = []
    crys_feas_list = []

    with torch.inference_mode():
        # `inputs` (not `input`) avoids shadowing the builtin.
        for index, (inputs, target, cif_id) in enumerate(loader, start=1):
            atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = inputs
            atom_fea = atom_fea.to(device)
            nbr_fea = nbr_fea.to(device)
            nbr_fea_idx = nbr_fea_idx.to(device)
            crystal_atom_idx = [idx_map.to(device) for idx_map in crystal_atom_idx]

            output, crys_fea = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

            outputs_list.extend(output.cpu().numpy().ravel().tolist())
            crys_feas_list.append(crys_fea.cpu().numpy())

            if verbose >= 4:
                # Unwrap batch-of-one containers for readable logging.
                cif_id_value = (
                    cif_id[0] if cif_id and isinstance(cif_id, list) else cif_id
                )
                prediction_value = (
                    output.item() if output.numel() == 1 else output.tolist()
                )
                logging.info(
                    f"index: {index} | cif id: {cif_id_value} | prediction: {prediction_value}"
                )

    return outputs_list, crys_feas_list

cgcnn_pred(model_path, full_set, verbose=4, cuda=False, num_workers=0)

This function takes the path to a pre-trained CGCNN model and a dataset, runs inference to generate predictions, and returns the predictions. It is not necessary to have target values for the predicted set.

Parameters:

Name Type Description Default
model_path str

Path to the file containing the pre-trained model parameters.

required
full_set str

Path to the directory containing all CIF files for the dataset.

required
verbose int

Verbosity level of the output. Defaults to 4.

4
cuda bool

Whether to use CUDA. Defaults to False.

False
num_workers int

Number of subprocesses for data loading. Defaults to 0.

0

Returns:

Name Type Description
tuple tuple[list[float], list[Tensor]]

A tuple containing: - list: Model predictions - list: Features from the last layer

Notes

This function is intended for use in programmatic downstream analysis, where the user wants to continue downstream analysis using predictions or features (descriptors) generated by the model. For the command-line interface, consider using the cgcnn_pr script instead.

Source code in cgcnn2/util.py
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
def cgcnn_pred(
    model_path: str,
    full_set: str,
    verbose: int = 4,
    cuda: bool = False,
    num_workers: int = 0,
) -> tuple[list[float], list[torch.Tensor]]:
    """
    Run inference with a pre-trained CGCNN model over a directory of CIF
    files and return the predictions. Target values are not required for
    the predicted set.

    Args:
        model_path (str): Path to the file containing the pre-trained model parameters.
        full_set (str): Path to the directory containing all CIF files for the dataset.
        verbose (int, optional): Verbosity level of the output. Defaults to 4.
        cuda (bool, optional): Whether to use CUDA. Defaults to False.
        num_workers (int, optional): Number of subprocesses for data loading. Defaults to 0.

    Returns:
        tuple: A tuple containing:
            - list: Model predictions
            - list: Features from the last layer

    Notes:
        Intended for programmatic downstream analysis where predictions or
        last-layer features (descriptors) are consumed in code. For the
        command-line interface, consider using the cgcnn_pr script instead.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"=> No model params found at '{model_path}'")

    dataset = CIFData_NoTarget(full_set)

    # Without CUDA, remap every storage onto the CPU; with CUDA, defer to
    # torch's default placement of the restored tensors.
    ckpt = torch.load(
        model_path,
        map_location=None if cuda else (lambda storage, loc: storage),
        weights_only=False,
    )

    # Infer the input feature widths from the first crystal of the dataset.
    first_item, _, _ = dataset[0]
    orig_atom_fea_len = first_item[0].shape[-1]
    nbr_fea_len = first_item[1].shape[-1]

    # Rebuild the network with the hyperparameters stored in the checkpoint.
    saved_args = argparse.Namespace(**ckpt["args"])
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=saved_args.atom_fea_len,
        n_conv=saved_args.n_conv,
        h_fea_len=saved_args.h_fea_len,
        n_h=saved_args.n_h,
    )
    if cuda:
        model.cuda()

    normalizer = Normalizer(torch.zeros(3))
    normalizer.load_state_dict(ckpt["normalizer"])
    model.load_state_dict(ckpt["state_dict"])

    if verbose >= 3:
        print_checkpoint_info(ckpt, model_path)

    device = "cuda" if cuda else "cpu"
    model.to(device).eval()

    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_pool,
        pin_memory=cuda,
    )

    return cgcnn_descriptor(model, loader, device, verbose)

cgcnn_test(model, loader, device, plot_file='parity_plot.png', results_file='results.csv', axis_limits=None, **kwargs)

This function takes a pre-trained CGCNN model and a test dataset, runs inference to generate predictions, creates a parity plot comparing predicted versus actual values, and writes the results to a CSV file.

Parameters:

Name Type Description Default
model Module

The pre-trained CGCNN model.

required
loader DataLoader

DataLoader for the dataset.

required
device str

The device ('cuda' or 'cpu') where the model will be run.

required
plot_file str

File path for saving the parity plot. Defaults to 'parity_plot.png'.

'parity_plot.png'
results_file str

File path for saving results as CSV. Defaults to 'results.csv'.

'results.csv'
axis_limits list

Limits for x-axis (Actual values) of the parity plot. Defaults to None.

None
**kwargs Any

Additional keyword arguments: xlabel (str): x-axis label for the parity plot. Defaults to "Actual". ylabel (str): y-axis label for the parity plot. Defaults to "Predicted".

{}
Notes

This function is intended for use in a command-line interface, providing direct output of results. For programmatic downstream analysis, consider using the API functions instead, i.e. cgcnn_pred and cgcnn_descriptor.

Source code in cgcnn2/util.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
def cgcnn_test(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: str,
    plot_file: str = "parity_plot.png",
    results_file: str = "results.csv",
    axis_limits: list[float] | None = None,
    **kwargs: Any,
) -> None:
    """
    This function takes a pre-trained CGCNN model and a test dataset, runs
    inference to generate predictions, creates a parity plot comparing predicted
    versus actual values, and writes the results to a CSV file.

    Args:
        model (torch.nn.Module): The pre-trained CGCNN model.
        loader (torch.utils.data.DataLoader): DataLoader for the dataset.
        device (str): The device ('cuda' or 'cpu') where the model will be run.
        plot_file (str, optional): File path for saving the parity plot. Defaults to 'parity_plot.png'.
        results_file (str, optional): File path for saving results as CSV. Defaults to 'results.csv'.
        axis_limits (list, optional): Limits for x-axis (Actual values) of the parity plot. Defaults to None.
        **kwargs: Additional keyword arguments:
            xlabel (str): x-axis label for the parity plot. Defaults to "Actual".
            ylabel (str): y-axis label for the parity plot. Defaults to "Predicted".

    Notes:
        This function is intended for use in a command-line interface, providing
        direct output of results. For programmatic downstream analysis, consider
        using the API functions instead, i.e. cgcnn_pred and cgcnn_descriptor.
    """

    # Extract optional plot labels from kwargs, with defaults
    xlabel = kwargs.get("xlabel", "Actual")
    ylabel = kwargs.get("ylabel", "Predicted")

    model.eval()
    targets_list = []
    outputs_list = []
    cif_ids = []

    with torch.inference_mode():
        for input_batch, target, cif_id in loader:
            atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = input_batch
            atom_fea = atom_fea.to(device)
            nbr_fea = nbr_fea.to(device)
            nbr_fea_idx = nbr_fea_idx.to(device)
            crystal_atom_idx = [idx_map.to(device) for idx_map in crystal_atom_idx]
            target = target.to(device)
            output, _ = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

            targets_list.extend(target.cpu().numpy().ravel().tolist())
            outputs_list.extend(output.cpu().numpy().ravel().tolist())
            cif_ids.extend(cif_id)

    targets_array = np.array(targets_list)
    outputs_array = np.array(outputs_list)

    # MSE and R2 Score. The residual sum of squares is computed once and
    # reused for both metrics instead of being recomputed.
    ss_res = np.sum((targets_array - outputs_array) ** 2)
    mse = ss_res / len(targets_array) if len(targets_array) else float("nan")
    ss_tot = np.sum((targets_array - np.mean(targets_array)) ** 2)
    # Constant targets make ss_tot zero; R2 is undefined there, so report NaN
    # instead of dividing by zero.
    r2 = 1 - ss_res / ss_tot if ss_tot > 0 else float("nan")
    logging.info(f"MSE: {mse:.4f}, R2 Score: {r2:.4f}")

    # Save results to CSV
    sorted_rows = sorted(zip(cif_ids, targets_list, outputs_list), key=lambda x: x[0])
    with open(results_file, "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["cif_id", "Actual", "Predicted"])
        writer.writerows(sorted_rows)
    logging.info(f"Prediction results have been saved to {results_file}")

    # Create parity plot
    df_full = pd.DataFrame({"Actual": targets_list, "Predicted": outputs_list})
    _make_and_save_parity(df_full, xlabel, ylabel, plot_file)
    logging.info(f"Parity plot has been saved to {plot_file}")

    # If axis limits are provided, save the csv file with the specified limits
    if axis_limits:
        df_clip = df_full[
            (df_full["Actual"] >= axis_limits[0])
            & (df_full["Actual"] <= axis_limits[1])
        ]
        # Derive the clipped-plot filename from the path stem. str.replace
        # would corrupt names containing ".png" elsewhere and silently do
        # nothing for plot files with a different extension.
        root, ext = os.path.splitext(plot_file)
        clipped_file = f"{root}_axis_limits_{axis_limits[0]}_{axis_limits[1]}{ext or '.png'}"
        _make_and_save_parity(df_clip, xlabel, ylabel, clipped_file)
        logging.info(
            f"Parity plot clipped to {axis_limits} on Actual has been saved to {clipped_file}"
        )

get_local_version()

Retrieves the version of the project from the pyproject.toml file.

Returns:

Name Type Description
version str

The version of the project.

Source code in cgcnn2/util.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
def get_local_version() -> str:
    """
    Retrieves the version of the project from the pyproject.toml file.

    Returns:
        version (str): The version of the project.
    """
    project_root = Path(__file__).parents[2]
    toml_path = project_root / "pyproject.toml"
    try:
        with toml_path.open("rb") as f:
            data = tomllib.load(f)
            version = data["project"]["version"]
        return version
    except Exception:
        return "unknown"

get_lr(optimizer)

Extracts learning rates from a PyTorch optimizer.

Parameters:

Name Type Description Default
optimizer Optimizer

The PyTorch optimizer to extract learning rates from.

required

Returns:

Name Type Description
learning_rates list[float]

A list of learning rates for each parameter group.

Source code in cgcnn2/util.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def get_lr(optimizer: torch.optim.Optimizer) -> list[float]:
    """
    Extracts learning rates from a PyTorch optimizer.

    Args:
        optimizer (torch.optim.Optimizer): The PyTorch optimizer to extract learning rates from.

    Returns:
        learning_rates (list[float]): A list of learning rates for each parameter group.
    """
    # One learning rate per parameter group, in group order.
    return [group["lr"] for group in optimizer.param_groups]

id_prop_gen(cif_dir)

Generates a CSV file containing IDs and properties of CIF files.

Parameters:

Name Type Description Default
cif_dir str

Directory containing the CIF files.

required
Source code in cgcnn2/util.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def id_prop_gen(cif_dir: str) -> None:
    """Generates a CSV file containing IDs and properties of CIF files.

    Writes ``id_prop.csv`` into *cif_dir* with one row per CIF file: the
    file's stem as the id and a placeholder property value of 0. No header
    row is written.

    Args:
        cif_dir (str): Directory containing the CIF files.
    """

    cif_list = glob.glob(f"{cif_dir}/*.cif")

    id_prop_cif = pd.DataFrame(
        {
            # splitext strips only the final extension, so file names that
            # contain dots (e.g. "mp-123.v2.cif") keep their full stem;
            # split(".")[0] would have truncated them to "mp-123".
            "id": [os.path.splitext(os.path.basename(cif))[0] for cif in cif_list],
            "prop": [0] * len(cif_list),
        }
    )

    id_prop_cif.to_csv(
        f"{cif_dir}/id_prop.csv",
        index=False,
        header=False,
    )

output_id_gen()

Generates a unique output identifier based on current date and time.

Returns:

Name Type Description
folder_name str

A string in format 'output_mmdd_HHMM' for current date/time.

Source code in cgcnn2/util.py
62
63
64
65
66
67
68
69
70
71
72
73
74
def output_id_gen() -> str:
    """
    Generates a unique output identifier based on current date and time.

    Returns:
        folder_name (str): A string in format 'output_mmdd_HHMM' for current date/time.
    """
    # Format the current time directly inside the f-string.
    return f"output_{datetime.now():%m%d_%H%M}"

print_checkpoint_info(checkpoint, model_path)

Prints the checkpoint information.

Parameters:

Name Type Description Default
checkpoint dict[str, Any]

The checkpoint dictionary.

required
model_path str

The path to the model file.

required
Source code in cgcnn2/util.py
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
def print_checkpoint_info(checkpoint: dict[str, Any], model_path: str) -> None:
    """
    Prints the checkpoint information.

    Args:
        checkpoint (dict[str, Any]): The checkpoint dictionary.
        model_path (str): The path to the model file.
    """
    epoch = checkpoint.get("epoch", "N/A")

    # Collect whichever validation metrics the checkpoint carries.
    metrics = [
        f"{label}={value:.4f}"
        for label, value in (
            ("MSE", checkpoint.get("best_mse_error")),
            ("MAE", checkpoint.get("best_mae_error")),
        )
        if value is not None
    ]
    metrics_str = ", ".join(metrics) if metrics else "N/A"

    logging.info(
        f"=> Loaded model from '{model_path}' (epoch {epoch}, validation {metrics_str})"
    )

seed_everything(seed)

Seeds the random number generators for Python, NumPy, PyTorch, and PyTorch CUDA.

Parameters:

Name Type Description Default
seed int

The seed value to use for random number generation.

required
Source code in cgcnn2/util.py
512
513
514
515
516
517
518
519
520
521
522
523
524
525
def seed_everything(seed: int) -> None:
    """
    Seeds the random number generators for Python, NumPy, PyTorch, and PyTorch CUDA.

    Args:
        seed (int): The seed value to use for random number generation.
    """
    # Seed every framework RNG with the same value.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

setup_logging()

Sets up logging for the project.

Source code in cgcnn2/util.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def setup_logging() -> None:
    """
    Sets up logging for the project.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logging.captureWarnings(True)

    # Record the environment versions once at startup.
    for label, value in (
        ("cgcnn2 version", cgcnn2.__version__),
        ("cuda version", torch.version.cuda),
        ("torch version", torch.__version__),
    ):
        logging.info(f"{label}: {value}")

unique_structures_clean(dataset_dir, delete_duplicates=False)

Checks for duplicate (structurally equivalent) structures in a directory of CIF files using pymatgen's StructureMatcher and returns the CIF filenames grouped by structural equivalence.

Parameters:

Name Type Description Default
dataset_dir str

The path to the dataset containing CIF files.

required
delete_duplicates bool

Whether to delete the duplicate structures, default is False.

False

Returns:

Name Type Description
grouped list

A list of lists, where each sublist contains the filenames of structurally equivalent structures.

Source code in cgcnn2/util.py
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
def unique_structures_clean(dataset_dir: str, delete_duplicates: bool = False) -> list[list[str]]:
    """
    Checks for duplicate (structurally equivalent) structures in a directory
    of CIF files using pymatgen's StructureMatcher and groups the CIF
    filenames by structural equivalence.

    Args:
        dataset_dir (str): The path to the dataset containing CIF files.
        delete_duplicates (bool): Whether to delete the duplicate structures, default is False.
            When True, the first file of each group is kept and the remaining
            files of that group are removed from disk.

    Returns:
        grouped_fnames (list[list[str]]): A list of lists, where each sublist
            contains the filenames of structurally equivalent structures.
    """
    cif_files = [f for f in os.listdir(dataset_dir) if f.endswith(".cif")]
    structures = []
    filenames = []

    # Parse every CIF file; structures[i] corresponds to filenames[i].
    for fname in cif_files:
        full_path = os.path.join(dataset_dir, fname)
        structures.append(Structure.from_file(full_path))
        filenames.append(fname)

    # Map each parsed Structure back to its filename via object identity.
    # NOTE(review): this assumes group_structures returns the original
    # Structure objects rather than copies — confirm against the pymatgen
    # version in use.
    id_to_fname = {id(s): fn for s, fn in zip(structures, filenames)}

    matcher = StructureMatcher()
    grouped = matcher.group_structures(structures)

    grouped_fnames = [[id_to_fname[id(s)] for s in group] for group in grouped]

    if delete_duplicates:
        for file_group in grouped_fnames:
            # keep the first file, delete the rest
            for dup_fname in file_group[1:]:
                os.remove(os.path.join(dataset_dir, dup_fname))

    return grouped_fnames