
API Reference

cgcnn2.data

AtomCustomJSONInitializer

Bases: AtomInitializer

Initialize atom feature vectors from a JSON file containing a Python dictionary that maps each element number to a list representing that element's feature vector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `elem_embedding_file` | `str` | The path to the `.json` file. | *required* |
Source code in cgcnn2/data.py
class AtomCustomJSONInitializer(AtomInitializer):
    """
    Initialize atom feature vectors from a JSON file containing a Python
    dictionary that maps each element number to a list representing that
    element's feature vector.

    Args:
        elem_embedding_file (str): The path to the `.json` file
    """

    def __init__(self, elem_embedding_file):
        """
        Initialize atom feature embeddings from a JSON file mapping element numbers to feature vectors.

        Parameters:
            elem_embedding_file (str): Path to a JSON file where keys are element numbers and values are feature vectors.
        """
        with open(elem_embedding_file) as f:
            elem_embedding = json.load(f)
        elem_embedding = {int(key): value for key, value in elem_embedding.items()}
        atom_types = set(elem_embedding.keys())
        super().__init__(atom_types)
        for key, value in elem_embedding.items():
            self._embedding[key] = np.array(value, dtype=float)

__init__(elem_embedding_file)

Initialize atom feature embeddings from a JSON file mapping element numbers to feature vectors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `elem_embedding_file` | `str` | Path to a JSON file where keys are element numbers and values are feature vectors. | *required* |
Source code in cgcnn2/data.py
def __init__(self, elem_embedding_file):
    """
    Initialize atom feature embeddings from a JSON file mapping element numbers to feature vectors.

    Parameters:
        elem_embedding_file (str): Path to a JSON file where keys are element numbers and values are feature vectors.
    """
    with open(elem_embedding_file) as f:
        elem_embedding = json.load(f)
    elem_embedding = {int(key): value for key, value in elem_embedding.items()}
    atom_types = set(elem_embedding.keys())
    super().__init__(atom_types)
    for key, value in elem_embedding.items():
        self._embedding[key] = np.array(value, dtype=float)
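
A minimal usage sketch, assuming an `atom_init.json` file (for example, the standard 92-dimensional embedding file shipped with CGCNN-style datasets) is present in the working directory:

```python
from cgcnn2.data import AtomCustomJSONInitializer

# Load element embeddings from JSON; keys are element numbers.
ari = AtomCustomJSONInitializer("atom_init.json")

# Look up the feature vector for oxygen (element number 8).
fea = ari.get_atom_fea(8)
print(fea.shape)  # (92,) for the standard CGCNN embedding file
```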

AtomInitializer

Base class for initializing the vector representation for atoms. Use one AtomInitializer per dataset.

Source code in cgcnn2/data.py
class AtomInitializer:
    """
    Base class for initializing the vector representation for atoms.
    Use one `AtomInitializer` per dataset.
    """

    def __init__(self, atom_types):
        """
        Initialize the atom types and embedding dictionary.

        Args:
            atom_types (set): A set of unique atom types in the dataset.
        """
        self.atom_types = set(atom_types)
        self._embedding = {}

    def get_atom_fea(self, atom_type):
        """
        Get the vector representation for an atom type.

        Args:
            atom_type (str): The type of atom to get the vector representation for.
        """
        assert atom_type in self.atom_types
        return self._embedding[atom_type]

    def load_state_dict(self, state_dict):
        """
        Load the state dictionary for the atom initializer.

        Args:
            state_dict (dict): The state dictionary to load.
        """
        self._embedding = state_dict
        self.atom_types = set(self._embedding.keys())
        self._decodedict = {
            idx: atom_type for atom_type, idx in self._embedding.items()
        }

    def state_dict(self) -> dict:
        """
        Get the state dictionary for the atom initializer.

        Returns:
            dict: The state dictionary.
        """
        return self._embedding

    def decode(self, idx: int) -> str:
        """
        Decode an index to an atom type.

        Args:
            idx (int): The index to decode.

        Returns:
            str: The decoded atom type.
        """
        if not hasattr(self, "_decodedict"):
            self._decodedict = {
                idx: atom_type for atom_type, idx in self._embedding.items()
            }
        return self._decodedict[idx]
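
A minimal sketch of the base-class contract, using hypothetical one-hot embeddings written directly to the private `_embedding` mapping (as subclasses such as `AtomCustomJSONInitializer` do):

```python
import numpy as np

from cgcnn2.data import AtomInitializer

# Hypothetical one-hot embeddings for H (1) and O (8).
init = AtomInitializer({1, 8})
init._embedding = {
    1: np.array([1.0, 0.0]),
    8: np.array([0.0, 1.0]),
}
print(init.get_atom_fea(8))  # [0. 1.]
```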

__init__(atom_types)

Initialize the atom types and embedding dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `atom_types` | `set` | A set of unique atom types in the dataset. | *required* |
Source code in cgcnn2/data.py
def __init__(self, atom_types):
    """
    Initialize the atom types and embedding dictionary.

    Args:
        atom_types (set): A set of unique atom types in the dataset.
    """
    self.atom_types = set(atom_types)
    self._embedding = {}

decode(idx)

Decode an index to an atom type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `idx` | `int` | The index to decode. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `str` | The decoded atom type. |

Source code in cgcnn2/data.py
def decode(self, idx: int) -> str:
    """
    Decode an index to an atom type.

    Args:
        idx (int): The index to decode.

    Returns:
        str: The decoded atom type.
    """
    if not hasattr(self, "_decodedict"):
        self._decodedict = {
            idx: atom_type for atom_type, idx in self._embedding.items()
        }
    return self._decodedict[idx]

get_atom_fea(atom_type)

Get the vector representation for an atom type.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `atom_type` | `str` | The type of atom to get the vector representation for. | *required* |
Source code in cgcnn2/data.py
def get_atom_fea(self, atom_type):
    """
    Get the vector representation for an atom type.

    Args:
        atom_type (str): The type of atom to get the vector representation for.
    """
    assert atom_type in self.atom_types
    return self._embedding[atom_type]

load_state_dict(state_dict)

Load the state dictionary for the atom initializer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state_dict` | `dict` | The state dictionary to load. | *required* |
Source code in cgcnn2/data.py
def load_state_dict(self, state_dict):
    """
    Load the state dictionary for the atom initializer.

    Args:
        state_dict (dict): The state dictionary to load.
    """
    self._embedding = state_dict
    self.atom_types = set(self._embedding.keys())
    self._decodedict = {
        idx: atom_type for atom_type, idx in self._embedding.items()
    }

state_dict()

Get the state dictionary for the atom initializer.

Returns:

| Type | Description |
| --- | --- |
| `dict` | The state dictionary. |

Source code in cgcnn2/data.py
def state_dict(self) -> dict:
    """
    Get the state dictionary for the atom initializer.

    Returns:
        dict: The state dictionary.
    """
    return self._embedding

CIFData

Bases: Dataset

The CIFData dataset is a wrapper for a dataset where the crystal structures are stored in the form of CIF files.

id_prop.csv: a CSV file with two columns. The first column records a unique ID for each crystal, and the second column records the value of the target property.

atom_init.json: a JSON file that stores the initialization vector for each element.

ID.cif: a CIF file that records the crystal structure, where ID is the unique ID for the crystal.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `root_dir` | `str` | The path to the root directory of the dataset. | *required* |
| `max_num_nbr` | `int` | The maximum number of neighbors when constructing the crystal graph. | `12` |
| `radius` | `float` | The cutoff radius for searching neighbors. | `8` |
| `dmin` | `float` | The minimum distance for constructing GaussianDistance. | `0` |
| `step` | `float` | The step size for constructing GaussianDistance. | `0.2` |
| `cache_size` | `int \| None` | The size of the LRU cache for the dataset. | `None` |
| `random_seed` | `int` | Random seed for shuffling the dataset. | `123` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `atom_fea` | `torch.Tensor` | shape `(n_i, atom_fea_len)` |
| `nbr_fea` | `torch.Tensor` | shape `(n_i, M, nbr_fea_len)` |
| `nbr_fea_idx` | `torch.LongTensor` | shape `(n_i, M)` |
| `target` | `torch.Tensor` | shape `(1,)` |
| `cif_id` | `str` or `int` | Unique ID for the crystal |

Source code in cgcnn2/data.py
class CIFData(Dataset):
    """
    The CIFData dataset is a wrapper for a dataset where the crystal structures
    are stored in the form of CIF files.

    `id_prop.csv`: a CSV file with two columns. The first column records a
    unique ID for each crystal, and the second column records the value of
    the target property.

    `atom_init.json`: a JSON file that stores the initialization vector for each
    element.

    `ID.cif`: a CIF file that records the crystal structure, where ID is the
    unique ID for the crystal.

    Args:
        root_dir (str): The path to the root directory of the dataset
        max_num_nbr (int): The maximum number of neighbors while constructing the crystal graph
        radius (float): The cutoff radius for searching neighbors
        dmin (float): The minimum distance for constructing GaussianDistance
        step (float): The step size for constructing GaussianDistance
        cache_size (int | None): The size of the lru cache for the dataset. Default is None.
        random_seed (int): Random seed for shuffling the dataset

    Returns:
        atom_fea (torch.Tensor): shape (n_i, atom_fea_len)
        nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len)
        nbr_fea_idx (torch.LongTensor): shape (n_i, M)
        target (torch.Tensor): shape (1, )
        cif_id (str or int): Unique ID for the crystal
    """

    def __init__(
        self,
        root_dir,
        max_num_nbr=12,
        radius=8,
        dmin=0,
        step=0.2,
        cache_size=None,
        random_seed=123,
    ):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        assert os.path.exists(root_dir), "root_dir does not exist!"
        id_prop_file = os.path.join(self.root_dir, "id_prop.csv")
        assert os.path.exists(id_prop_file), "id_prop.csv does not exist!"
        with open(id_prop_file) as f:
            reader = csv.reader(f)
            self.id_prop_data = [row for row in reader]
        random.seed(random_seed)
        atom_init_file = os.path.join(self.root_dir, "atom_init.json")
        assert os.path.exists(atom_init_file), "atom_init.json does not exist!"
        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)
        self._raw_load_item = self._load_item_fast
        self.cache_size = cache_size
        self._configure_cache()

    def set_cache_size(self, cache_size: Optional[int]) -> None:
        """
        Change the LRU-cache capacity on the fly.

        Args:
            cache_size (int | None): The size of the cache to set, None for unlimited size. Default is None.
        """
        self.cache_size = cache_size
        if hasattr(self._cache_load, "cache_clear"):
            self._cache_load.cache_clear()
        self._configure_cache()

    def clear_cache(self) -> None:
        """
        Clear the current cache.
        """
        if hasattr(self._cache_load, "cache_clear"):
            self._cache_load.cache_clear()

    def __len__(self):
        return len(self.id_prop_data)

    def __getitem__(self, idx):
        return self._cache_load(idx)

    def _configure_cache(self) -> None:
        """
        Wrap `_raw_load_item` with an LRU cache.
        """
        if self.cache_size is None:
            self._cache_load = functools.lru_cache(maxsize=None)(self._raw_load_item)
        elif self.cache_size <= 0:
            self._cache_load = self._raw_load_item
        else:
            self._cache_load = functools.lru_cache(maxsize=self.cache_size)(
                self._raw_load_item
            )

    def _load_item(self, idx):
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + ".cif"))
        atom_fea = np.vstack(
            [
                self.ari.get_atom_fea(crystal[i].specie.number)
                for i in range(len(crystal))
            ]
        )
        atom_fea = torch.Tensor(atom_fea)
        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < self.max_num_nbr:
                warnings.warn(
                    "{} did not find enough neighbors to build the graph. "
                    "If this happens frequently, consider increasing the "
                    "radius.".format(cif_id),
                    stacklevel=2,
                )
                nbr_fea_idx.append(
                    list(map(lambda x: x[2], nbr)) + [0] * (self.max_num_nbr - len(nbr))
                )
                nbr_fea.append(
                    list(map(lambda x: x[1], nbr))
                    + [self.radius + 1.0] * (self.max_num_nbr - len(nbr))
                )
            else:
                nbr_fea_idx.append(list(map(lambda x: x[2], nbr[: self.max_num_nbr])))
                nbr_fea.append(list(map(lambda x: x[1], nbr[: self.max_num_nbr])))
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = self.gdf.expand(nbr_fea)
        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        target = torch.Tensor([float(target)])
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id

    def _load_item_fast(self, idx):
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + ".cif"))
        atom_fea = np.vstack(
            [
                self.ari.get_atom_fea(crystal[i].specie.number)
                for i in range(len(crystal))
            ]
        )
        atom_fea = torch.Tensor(atom_fea)
        center_idx, neigh_idx, _images, dists = crystal.get_neighbor_list(self.radius)
        n_sites = len(crystal)
        bucket = [[] for _ in range(n_sites)]
        for c, n, d in zip(center_idx, neigh_idx, dists):
            bucket[c].append((n, d))
        bucket = [sorted(lst, key=lambda x: x[1]) for lst in bucket]
        nbr_fea_idx, nbr_fea = [], []
        for lst in bucket:
            if len(lst) < self.max_num_nbr:
                warnings.warn(
                    f"{cif_id} did not find enough neighbors to build the graph. "
                    "If this happens frequently, consider increasing the "
                    "radius.",
                    stacklevel=2,
                )
            idxs = [t[0] for t in lst[: self.max_num_nbr]]
            dvec = [t[1] for t in lst[: self.max_num_nbr]]
            pad = self.max_num_nbr - len(idxs)
            nbr_fea_idx.append(idxs + [0] * pad)
            nbr_fea.append(dvec + [self.radius + 1.0] * pad)
        nbr_fea_idx = torch.as_tensor(np.array(nbr_fea_idx), dtype=torch.long)
        nbr_fea = self.gdf.expand(np.array(nbr_fea))
        nbr_fea = torch.Tensor(nbr_fea)
        target = torch.tensor([float(target)])
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id
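
A usage sketch, assuming a hypothetical directory `data/` that contains `id_prop.csv`, `atom_init.json`, and one `ID.cif` file per row of the CSV:

```python
from torch.utils.data import DataLoader

from cgcnn2.data import CIFData, collate_pool

dataset = CIFData("data/", max_num_nbr=12, radius=8)

# A single item unpacks as ((atom_fea, nbr_fea, nbr_fea_idx), target, cif_id).
(atom_fea, nbr_fea, nbr_fea_idx), target, cif_id = dataset[0]

# For batching, pass collate_pool as the collate function.
loader = DataLoader(dataset, batch_size=32, collate_fn=collate_pool)
```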

clear_cache()

Clear the current cache.

Source code in cgcnn2/data.py
def clear_cache(self) -> None:
    """
    Clear the current cache.
    """
    if hasattr(self._cache_load, "cache_clear"):
        self._cache_load.cache_clear()

set_cache_size(cache_size)

Change the LRU-cache capacity on the fly.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cache_size` | `int \| None` | The size of the cache to set; `None` for unlimited size. | *required* |
Source code in cgcnn2/data.py
def set_cache_size(self, cache_size: Optional[int]) -> None:
    """
    Change the LRU-cache capacity on the fly.

    Args:
        cache_size (int | None): The size of the cache to set, None for unlimited size. Default is None.
    """
    self.cache_size = cache_size
    if hasattr(self._cache_load, "cache_clear"):
        self._cache_load.cache_clear()
    self._configure_cache()

CIFData_NoTarget

Bases: Dataset

The CIFData_NoTarget dataset is a wrapper for a dataset where the crystal structures are stored in the form of CIF files.

atom_init.json: a JSON file that stores the initialization vector for each element.

ID.cif: a CIF file that records the crystal structure, where ID is the unique ID for the crystal.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `root_dir` | `str` | The path to the root directory of the dataset. | *required* |
| `max_num_nbr` | `int` | The maximum number of neighbors when constructing the crystal graph. | `12` |
| `radius` | `float` | The cutoff radius for searching neighbors. | `8` |
| `dmin` | `float` | The minimum distance for constructing GaussianDistance. | `0` |
| `step` | `float` | The step size for constructing GaussianDistance. | `0.2` |
| `random_seed` | `int` | Random seed for shuffling the dataset. | `123` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `atom_fea` | `torch.Tensor` | shape `(n_i, atom_fea_len)` |
| `nbr_fea` | `torch.Tensor` | shape `(n_i, M, nbr_fea_len)` |
| `nbr_fea_idx` | `torch.LongTensor` | shape `(n_i, M)` |
| `target` | `torch.Tensor` | shape `(1,)` |
| `cif_id` | `str` or `int` | Unique ID for the crystal |

Source code in cgcnn2/data.py
class CIFData_NoTarget(Dataset):
    """
    The CIFData_NoTarget dataset is a wrapper for a dataset where the crystal
    structures are stored in the form of CIF files.

    `atom_init.json`: a JSON file that stores the initialization vector for each
    element.

    `ID.cif`: a CIF file that records the crystal structure, where ID is the
    unique ID for the crystal.

    Args:
        root_dir (str): The path to the root directory of the dataset
        max_num_nbr (int): The maximum number of neighbors while constructing the crystal graph
        radius (float): The cutoff radius for searching neighbors
        dmin (float): The minimum distance for constructing GaussianDistance
        step (float): The step size for constructing GaussianDistance
        random_seed (int): Random seed for shuffling the dataset

    Returns:
        atom_fea (torch.Tensor): shape (n_i, atom_fea_len)
        nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len)
        nbr_fea_idx (torch.LongTensor): shape (n_i, M)
        target (torch.Tensor): shape (1, )
        cif_id (str or int): Unique ID for the crystal
    """

    def __init__(
        self, root_dir, max_num_nbr=12, radius=8, dmin=0, step=0.2, random_seed=123
    ):
        self.root_dir = root_dir
        self.max_num_nbr, self.radius = max_num_nbr, radius
        assert os.path.exists(root_dir), "root_dir does not exist!"
        id_prop_data = []
        for file in os.listdir(root_dir):
            if file.endswith(".cif"):
                id_prop_data.append(file[:-4])
        id_prop_data = [(cif_id, 0) for cif_id in id_prop_data]
        id_prop_data.sort(key=lambda x: x[0])
        self.id_prop_data = id_prop_data
        random.seed(random_seed)
        atom_init_file = os.path.join(self.root_dir, "atom_init.json")
        assert os.path.exists(atom_init_file), "atom_init.json does not exist!"
        self.ari = AtomCustomJSONInitializer(atom_init_file)
        self.gdf = GaussianDistance(dmin=dmin, dmax=self.radius, step=step)

    def __len__(self):
        return len(self.id_prop_data)

    @functools.lru_cache(maxsize=None)  # Cache loaded structures
    def __getitem__(self, idx):
        cif_id, target = self.id_prop_data[idx]
        crystal = Structure.from_file(os.path.join(self.root_dir, cif_id + ".cif"))
        atom_fea = np.vstack(
            [
                self.ari.get_atom_fea(crystal[i].specie.number)
                for i in range(len(crystal))
            ]
        )
        atom_fea = torch.Tensor(atom_fea)
        all_nbrs = crystal.get_all_neighbors(self.radius, include_index=True)
        all_nbrs = [sorted(nbrs, key=lambda x: x[1]) for nbrs in all_nbrs]
        nbr_fea_idx, nbr_fea = [], []
        for nbr in all_nbrs:
            if len(nbr) < self.max_num_nbr:
                warnings.warn(
                    "{} did not find enough neighbors to build the graph. "
                    "If this happens frequently, consider increasing the "
                    "radius.".format(cif_id),
                    stacklevel=2,
                )
                nbr_fea_idx.append(
                    list(map(lambda x: x[2], nbr)) + [0] * (self.max_num_nbr - len(nbr))
                )
                nbr_fea.append(
                    list(map(lambda x: x[1], nbr))
                    + [self.radius + 1.0] * (self.max_num_nbr - len(nbr))
                )
            else:
                nbr_fea_idx.append(list(map(lambda x: x[2], nbr[: self.max_num_nbr])))
                nbr_fea.append(list(map(lambda x: x[1], nbr[: self.max_num_nbr])))
        nbr_fea_idx, nbr_fea = np.array(nbr_fea_idx), np.array(nbr_fea)
        nbr_fea = self.gdf.expand(nbr_fea)
        nbr_fea = torch.Tensor(nbr_fea)
        nbr_fea_idx = torch.LongTensor(nbr_fea_idx)
        target = torch.Tensor([float(target)])
        return (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id
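
A usage sketch, assuming a hypothetical directory `cifs/` that contains `atom_init.json` and a set of `.cif` files; since there are no labels, every target is a dummy value of 0:

```python
from cgcnn2.data import CIFData_NoTarget

dataset = CIFData_NoTarget("cifs/")
(atom_fea, nbr_fea, nbr_fea_idx), target, cif_id = dataset[0]
print(cif_id, float(target))  # target is always 0.0
```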

GaussianDistance

Expands the distance by Gaussian basis.

Unit: angstrom

Source code in cgcnn2/data.py
class GaussianDistance:
    """
    Expands the distance by Gaussian basis.

    Unit: angstrom
    """

    def __init__(self, dmin, dmax, step, var=None):
        """
        Args:
            dmin (float): Minimum interatomic distance (center of the first Gaussian).
            dmax (float): Maximum interatomic distance (center of the last Gaussian).
            step (float): Spacing between consecutive Gaussian centers.
            var (float, optional): Variance of each Gaussian. If None, defaults to step.
        """

        assert dmin < dmax
        assert dmax - dmin > step
        self.filter = np.arange(dmin, dmax + step, step)
        if var is None:
            var = step
        self.var = var

    def expand(self, distances):
        """
        Project each scalar distance onto a set of Gaussian basis functions.

        Args:
            distances (np.ndarray): An array of interatomic distances.

        Returns:
            expanded_distance (np.ndarray): An array where the last dimension contains the Gaussian basis values for each input distance.
        """

        expanded_distance = np.exp(
            -((distances[..., np.newaxis] - self.filter) ** 2) / self.var**2
        )
        return expanded_distance
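
A worked sketch: with `dmin=0`, `dmax=8`, and `step=0.2` the filter holds 41 Gaussian centers (0.0, 0.2, ..., 8.0), so expanding a `(2, 3)` array of distances yields an array of shape `(2, 3, 41)`:

```python
import numpy as np

from cgcnn2.data import GaussianDistance

gdf = GaussianDistance(dmin=0, dmax=8, step=0.2)
distances = np.array([[1.0, 2.5, 3.0], [0.5, 4.2, 7.9]])
print(gdf.expand(distances).shape)  # (2, 3, 41)
```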

__init__(dmin, dmax, step, var=None)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dmin` | `float` | Minimum interatomic distance (center of the first Gaussian). | *required* |
| `dmax` | `float` | Maximum interatomic distance (center of the last Gaussian). | *required* |
| `step` | `float` | Spacing between consecutive Gaussian centers. | *required* |
| `var` | `float`, optional | Variance of each Gaussian. If `None`, defaults to `step`. | `None` |
Source code in cgcnn2/data.py
def __init__(self, dmin, dmax, step, var=None):
    """
    Args:
        dmin (float): Minimum interatomic distance (center of the first Gaussian).
        dmax (float): Maximum interatomic distance (center of the last Gaussian).
        step (float): Spacing between consecutive Gaussian centers.
        var (float, optional): Variance of each Gaussian. If None, defaults to step.
    """

    assert dmin < dmax
    assert dmax - dmin > step
    self.filter = np.arange(dmin, dmax + step, step)
    if var is None:
        var = step
    self.var = var

expand(distances)

Project each scalar distance onto a set of Gaussian basis functions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `distances` | `np.ndarray` | An array of interatomic distances. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `expanded_distance` | `np.ndarray` | An array where the last dimension contains the Gaussian basis values for each input distance. |

Source code in cgcnn2/data.py
def expand(self, distances):
    """
    Project each scalar distance onto a set of Gaussian basis functions.

    Args:
        distances (np.ndarray): An array of interatomic distances.

    Returns:
        expanded_distance (np.ndarray): An array where the last dimension contains the Gaussian basis values for each input distance.
    """

    expanded_distance = np.exp(
        -((distances[..., np.newaxis] - self.filter) ** 2) / self.var**2
    )
    return expanded_distance

collate_pool(dataset_list)

Collate a list of data and return a batch for predicting crystal properties.

Parameters:

- `dataset_list` (list of tuples, *required*): List of tuples for each data point. Each tuple contains:
    - `atom_fea` (`torch.Tensor`): shape `(n_i, atom_fea_len)`. Atom features for each atom in the crystal.
    - `nbr_fea` (`torch.Tensor`): shape `(n_i, M, nbr_fea_len)`. Bond features for each atom's `M` neighbors.
    - `nbr_fea_idx` (`torch.LongTensor`): shape `(n_i, M)`. Indices of the `M` neighbors of each atom.
    - `target` (`torch.Tensor`): shape `(1,)`. Target value for prediction.
    - `cif_id` (`str` or `int`): Unique ID for the crystal.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `batch_atom_fea` | `torch.Tensor` | shape `(N, orig_atom_fea_len)`. Atom features from atom type. |
| `batch_nbr_fea` | `torch.Tensor` | shape `(N, M, nbr_fea_len)`. Bond features of each atom's `M` neighbors. |
| `batch_nbr_fea_idx` | `torch.LongTensor` | shape `(N, M)`. Indices of the `M` neighbors of each atom. |
| `crystal_atom_idx` | list of `torch.LongTensor` | length `N0`. Mapping from the crystal index to atom indices. |
| `batch_target` | `torch.Tensor` | shape `(N0, 1)`. Target values for prediction. |
| `batch_cif_ids` | list of `str` or `int` | Unique IDs for each crystal. |

Source code in cgcnn2/data.py
def collate_pool(dataset_list):
    """
    Collate a list of data and return a batch for predicting crystal properties.

    Args:
        dataset_list (list of tuples): List of tuples for each data point. Each tuple contains:
            atom_fea (torch.Tensor): shape (n_i, atom_fea_len) Atom features for each atom in the crystal
            nbr_fea (torch.Tensor): shape (n_i, M, nbr_fea_len) Bond features for each atom's M neighbors
            nbr_fea_idx (torch.LongTensor): shape (n_i, M) Indices of M neighbors of each atom
            target (torch.Tensor): shape (1, ) Target value for prediction
            cif_id (str or int): Unique ID for the crystal

    Returns:
        batch_atom_fea (torch.Tensor): shape (N, orig_atom_fea_len) Atom features from atom type
        batch_nbr_fea (torch.Tensor): shape (N, M, nbr_fea_len) Bond features of each atom's M neighbors
        batch_nbr_fea_idx (torch.LongTensor): shape (N, M) Indices of M neighbors of each atom
        crystal_atom_idx (list of torch.LongTensor): length N0 Mapping from the crystal idx to atom idx
        batch_target (torch.Tensor): shape (N0, 1) Target values for prediction
        batch_cif_ids (list of str or int): Unique IDs for each crystal
    """

    batch_atom_fea, batch_nbr_fea, batch_nbr_fea_idx = [], [], []
    crystal_atom_idx, batch_target = [], []
    batch_cif_ids = []
    base_idx = 0
    for (atom_fea, nbr_fea, nbr_fea_idx), target, cif_id in dataset_list:
        n_i = atom_fea.shape[0]  # number of atoms for this crystal
        batch_atom_fea.append(atom_fea)
        batch_nbr_fea.append(nbr_fea)
        batch_nbr_fea_idx.append(nbr_fea_idx + base_idx)
        new_idx = torch.LongTensor(np.arange(n_i) + base_idx)
        crystal_atom_idx.append(new_idx)
        batch_target.append(target)
        batch_cif_ids.append(cif_id)
        base_idx += n_i
    return (
        (
            torch.cat(batch_atom_fea, dim=0),
            torch.cat(batch_nbr_fea, dim=0),
            torch.cat(batch_nbr_fea_idx, dim=0),
            crystal_atom_idx,
        ),
        torch.stack(batch_target, dim=0),
        batch_cif_ids,
    )
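
Continuing the `CIFData` sketch above, a batch produced by a `DataLoader` with `collate_fn=collate_pool` unpacks as follows:

```python
for (atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx), target, cif_ids in loader:
    # atom_fea: (N, orig_atom_fea_len); target: (N0, 1); one ID per crystal.
    print(atom_fea.shape, target.shape, len(cif_ids))
    break
```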

full_set_split(full_set_dir, train_ratio, valid_ratio, train_force_dir=None, random_seed=0)

Split the full set into train, validation, and test sets, each materialized in a temporary directory.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `full_set_dir` | `str` | The path to the full set. | *required* |
| `train_ratio` | `float` | The ratio of the training set. | *required* |
| `valid_ratio` | `float` | The ratio of the validation set. | *required* |
| `train_force_dir` | `str` | The path to a training set that is force-included; providing it means the original split ratio is no longer preserved. | `None` |
| `random_seed` | `int` | The random seed for the split. | `0` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `train_dir` | `str` | The path to a temporary directory containing the train set. |
| `valid_dir` | `str` | The path to a temporary directory containing the valid set. |
| `test_dir` | `str` | The path to a temporary directory containing the test set. |

Source code in cgcnn2/data.py
def full_set_split(
    full_set_dir: str,
    train_ratio: float,
    valid_ratio: float,
    train_force_dir: str | None = None,
    random_seed: int = 0,
):
    """
    Split the full set into train, validation, and test sets, each materialized in a temporary directory.

    Args:
        full_set_dir (str): The path to the full set
        train_ratio (float): The ratio of the training set
        valid_ratio (float): The ratio of the validation set
        train_force_dir (str): The path to a training set that is force-included; providing it means the original split ratio is no longer preserved.
        random_seed (int): The random seed for the split

    Returns:
        train_dir (str): The path to a temporary directory containing the train set
        valid_dir (str): The path to a temporary directory containing the valid set
        test_dir (str): The path to a temporary directory containing the test set
    """
    df = pd.read_csv(
        os.path.join(full_set_dir, "id_prop.csv"),
        header=None,
        names=["cif_id", "property"],
    )

    rng = np.random.RandomState(random_seed)
    df_shuffle = df.sample(frac=1.0, random_state=rng).reset_index(drop=True)

    n_total = len(df_shuffle)
    n_train = int(n_total * train_ratio)
    n_valid = int(n_total * valid_ratio)

    train_df = df_shuffle[:n_train]
    valid_df = df_shuffle[n_train : n_train + n_valid]
    test_df = df_shuffle[n_train + n_valid :]

    temp_train_dir = tempfile.mkdtemp()
    temp_valid_dir = tempfile.mkdtemp()
    temp_test_dir = tempfile.mkdtemp()

    atexit.register(shutil.rmtree, temp_train_dir, ignore_errors=True)
    atexit.register(shutil.rmtree, temp_valid_dir, ignore_errors=True)
    atexit.register(shutil.rmtree, temp_test_dir, ignore_errors=True)

    splits = {
        temp_train_dir: train_df,
        temp_valid_dir: valid_df,
        temp_test_dir: test_df,
    }

    for temp_dir, df in splits.items():
        for cif_id in df["cif_id"]:
            src = os.path.join(full_set_dir, f"{cif_id}.cif")
            dst = os.path.join(temp_dir, f"{cif_id}.cif")
            shutil.copy(src, dst)

    train_df.to_csv(
        os.path.join(temp_train_dir, "id_prop.csv"), index=False, header=False
    )
    valid_df.to_csv(
        os.path.join(temp_valid_dir, "id_prop.csv"), index=False, header=False
    )
    test_df.to_csv(
        os.path.join(temp_test_dir, "id_prop.csv"), index=False, header=False
    )

    shutil.copy(os.path.join(full_set_dir, "atom_init.json"), temp_train_dir)
    shutil.copy(os.path.join(full_set_dir, "atom_init.json"), temp_valid_dir)
    shutil.copy(os.path.join(full_set_dir, "atom_init.json"), temp_test_dir)

    if train_force_dir is not None:
        df_force = pd.read_csv(
            os.path.join(train_force_dir, "id_prop.csv"),
            header=None,
            names=["cif_id", "property"],
        )

        train_df = pd.concat([train_df, df_force])
        train_df.to_csv(
            os.path.join(temp_train_dir, "id_prop.csv"), index=False, header=False
        )

        for cif_id in df_force["cif_id"]:
            src = os.path.join(train_force_dir, f"{cif_id}.cif")
            dst = os.path.join(temp_train_dir, f"{cif_id}.cif")
            shutil.copy(src, dst)

    return temp_train_dir, temp_valid_dir, temp_test_dir
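
A usage sketch: a 70/15/15 split of the hypothetical `data/` directory used above. The returned directories are temporary and are removed automatically at interpreter exit:

```python
from cgcnn2.data import full_set_split

train_dir, valid_dir, test_dir = full_set_split(
    "data/", train_ratio=0.7, valid_ratio=0.15, random_seed=0
)
```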

cgcnn2.model

ConvLayer

Bases: Module

Convolutional layer for graph data.

Performs a convolutional operation on graphs, updating atom features based on their neighbors.

Source code in cgcnn2/model.py
class ConvLayer(nn.Module):
    """
    Convolutional layer for graph data.

    Performs a convolutional operation on graphs, updating atom features based on their neighbors.
    """

    def __init__(self, atom_fea_len: int, nbr_fea_len: int) -> None:
        """
        Initialize the ConvLayer.

        Args:
            atom_fea_len (int): Number of atom hidden features.
            nbr_fea_len (int): Number of bond (neighbor) features.
        """
        super().__init__()
        self.atom_fea_len = atom_fea_len
        self.nbr_fea_len = nbr_fea_len
        self.fc_full = nn.Linear(
            2 * self.atom_fea_len + self.nbr_fea_len, 2 * self.atom_fea_len
        )
        self.sigmoid = nn.Sigmoid()
        self.softplus1 = nn.Softplus()
        self.bn1 = nn.BatchNorm1d(2 * self.atom_fea_len)
        self.bn2 = nn.BatchNorm1d(self.atom_fea_len)
        self.softplus2 = nn.Softplus()

    def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
        """
        Forward pass.

        `N`: Total number of atoms in the batch
        `M`: Max number of neighbors

        Args:
            atom_in_fea (torch.Tensor): Tensor of shape `(N, atom_fea_len)` containing the atom features before convolution.
            nbr_fea (torch.Tensor): Tensor of shape `(N, M, nbr_fea_len)` holding the bond features for each atom's `M` neighbors.
            nbr_fea_idx (torch.LongTensor): Tensor of shape `(N, M)` with the indices of the `M` neighbors of each atom.

        Returns:
            atom_out_fea (torch.Tensor): Tensor of shape `(N, atom_fea_len)` with the atom features after convolution.

        """
        N, M = nbr_fea_idx.shape
        # Convolution
        atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]
        total_nbr_fea = torch.cat(
            [
                atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len),
                atom_nbr_fea,
                nbr_fea,
            ],
            dim=2,
        )
        total_gated_fea = self.fc_full(total_nbr_fea)
        total_gated_fea = self.bn1(
            total_gated_fea.view(-1, self.atom_fea_len * 2)
        ).view(N, M, self.atom_fea_len * 2)
        nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2)
        nbr_filter = self.sigmoid(nbr_filter)
        nbr_core = self.softplus1(nbr_core)
        nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1)
        nbr_sumed = self.bn2(nbr_sumed)
        atom_out_fea = self.softplus2(atom_in_fea + nbr_sumed)
        return atom_out_fea
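
A minimal sketch of one convolution step on random data, with `N=6` atoms, `M=12` neighbors, 64 atom features, and 41 bond features:

```python
import torch

from cgcnn2.model import ConvLayer

conv = ConvLayer(atom_fea_len=64, nbr_fea_len=41)
atom_fea = torch.randn(6, 64)
nbr_fea = torch.randn(6, 12, 41)
nbr_fea_idx = torch.randint(0, 6, (6, 12))

out = conv(atom_fea, nbr_fea, nbr_fea_idx)
print(out.shape)  # torch.Size([6, 64])
```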

__init__(atom_fea_len, nbr_fea_len)

Initialize the ConvLayer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `atom_fea_len` | `int` | Number of atom hidden features. | *required* |
| `nbr_fea_len` | `int` | Number of bond (neighbor) features. | *required* |
Source code in cgcnn2/model.py
def __init__(self, atom_fea_len: int, nbr_fea_len: int) -> None:
    """
    Initialize the ConvLayer.

    Args:
        atom_fea_len (int): Number of atom hidden features.
        nbr_fea_len (int): Number of bond (neighbor) features.
    """
    super().__init__()
    self.atom_fea_len = atom_fea_len
    self.nbr_fea_len = nbr_fea_len
    self.fc_full = nn.Linear(
        2 * self.atom_fea_len + self.nbr_fea_len, 2 * self.atom_fea_len
    )
    self.sigmoid = nn.Sigmoid()
    self.softplus1 = nn.Softplus()
    self.bn1 = nn.BatchNorm1d(2 * self.atom_fea_len)
    self.bn2 = nn.BatchNorm1d(self.atom_fea_len)
    self.softplus2 = nn.Softplus()

forward(atom_in_fea, nbr_fea, nbr_fea_idx)

Forward pass.

`N`: Total number of atoms in the batch
`M`: Max number of neighbors

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `atom_in_fea` | `torch.Tensor` | Tensor of shape `(N, atom_fea_len)` containing the atom features before convolution. | *required* |
| `nbr_fea` | `torch.Tensor` | Tensor of shape `(N, M, nbr_fea_len)` holding the bond features for each atom's `M` neighbors. | *required* |
| `nbr_fea_idx` | `torch.LongTensor` | Tensor of shape `(N, M)` with the indices of the `M` neighbors of each atom. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `atom_out_fea` | `torch.Tensor` | Tensor of shape `(N, atom_fea_len)` with the atom features after convolution. |

Source code in cgcnn2/model.py
def forward(self, atom_in_fea, nbr_fea, nbr_fea_idx):
    """
    Forward pass.

    `N`: Total number of atoms in the batch
    `M`: Max number of neighbors

    Args:
        atom_in_fea (torch.Tensor): Tensor of shape `(N, atom_fea_len)` containing the atom features before convolution.
        nbr_fea (torch.Tensor): Tensor of shape `(N, M, nbr_fea_len)` holding the bond features for each atom's `M` neighbors.
        nbr_fea_idx (torch.LongTensor): Tensor of shape `(N, M)` with the indices of the `M` neighbors of each atom.

    Returns:
        atom_out_fea (torch.Tensor): Tensor of shape `(N, atom_fea_len)` with the atom features after convolution.

    """
    N, M = nbr_fea_idx.shape
    # Convolution
    atom_nbr_fea = atom_in_fea[nbr_fea_idx, :]
    total_nbr_fea = torch.cat(
        [
            atom_in_fea.unsqueeze(1).expand(N, M, self.atom_fea_len),
            atom_nbr_fea,
            nbr_fea,
        ],
        dim=2,
    )
    total_gated_fea = self.fc_full(total_nbr_fea)
    total_gated_fea = self.bn1(
        total_gated_fea.view(-1, self.atom_fea_len * 2)
    ).view(N, M, self.atom_fea_len * 2)
    nbr_filter, nbr_core = total_gated_fea.chunk(2, dim=2)
    nbr_filter = self.sigmoid(nbr_filter)
    nbr_core = self.softplus1(nbr_core)
    nbr_sumed = torch.sum(nbr_filter * nbr_core, dim=1)
    nbr_sumed = self.bn2(nbr_sumed)
    atom_out_fea = self.softplus2(atom_in_fea + nbr_sumed)
    return atom_out_fea

CrystalGraphConvNet

Bases: Module

Create a crystal graph convolutional neural network for predicting total material properties.

Source code in cgcnn2/model.py
class CrystalGraphConvNet(nn.Module):
    """
    Create a crystal graph convolutional neural network for predicting total
    material properties.
    """

    def __init__(
        self,
        orig_atom_fea_len: int,
        nbr_fea_len: int,
        atom_fea_len: int = 64,
        n_conv: int = 3,
        h_fea_len: int = 128,
        n_h: int = 1,
        classification: bool = False,
    ) -> None:
        """
        Initialize CrystalGraphConvNet.

        Args:
            orig_atom_fea_len (int): Number of atom features in the input.
            nbr_fea_len (int): Number of bond features.
            atom_fea_len (int): Number of hidden atom features in the convolutional layers
            n_conv (int): Number of convolutional layers
            h_fea_len (int): Number of hidden features after pooling
            n_h (int): Number of hidden layers after pooling
            classification (bool): Whether to use classification or regression
        """
        super().__init__()
        self.classification = classification
        self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
        self.convs = nn.ModuleList(
            [
                ConvLayer(atom_fea_len=atom_fea_len, nbr_fea_len=nbr_fea_len)
                for _ in range(n_conv)
            ]
        )
        self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
        self.conv_to_fc_softplus = nn.Softplus()
        if n_h > 1:
            self.fcs = nn.ModuleList(
                [nn.Linear(h_fea_len, h_fea_len) for _ in range(n_h - 1)]
            )
            self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_h - 1)])

        if self.classification:
            self.fc_out = nn.Linear(h_fea_len, 2)
        else:
            self.fc_out = nn.Linear(h_fea_len, 1)

        if self.classification:
            self.logsoftmax = nn.LogSoftmax(dim=1)
            self.dropout = nn.Dropout()

    def forward(
        self,
        atom_fea: torch.Tensor,
        nbr_fea: torch.Tensor,
        nbr_fea_idx: torch.LongTensor,
        crystal_atom_idx: list[torch.LongTensor],
    ):
        """
        Forward pass.

        `N`: Total number of atoms in the batch
        `M`: Max number of neighbors
        `N0`: Total number of crystals in the batch

        Args:
            atom_fea (torch.Tensor): Tensor of shape `(N, orig_atom_fea_len)` containing the atom features from atom type.
            nbr_fea (torch.Tensor): Tensor of shape `(N, M, nbr_fea_len)` containing the bond features of each atom's `M` neighbors.
            nbr_fea_idx (torch.LongTensor): Tensor of shape `(N, M)` containing the indices of the `M` neighbors of each atom.
            crystal_atom_idx (list of torch.LongTensor): Mapping from the crystal index to atom index.

        Returns:
            out (torch.Tensor):
                • `(n_crystals, 2)` if `classification=True`, containing log-probabilities.
                • `(n_crystals, 1)` if `classification=False`, containing the regression output.
            crys_fea (torch.Tensor): Tensor of shape `(n_crystals, h_fea_len)` containing the pooled crystal embeddings.

        """
        atom_fea = self.embedding(atom_fea)
        for conv_func in self.convs:
            atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
        crys_fea = self.pooling(atom_fea, crystal_atom_idx)
        crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
        crys_fea = self.conv_to_fc_softplus(crys_fea)
        if self.classification:
            crys_fea = self.dropout(crys_fea)
        if hasattr(self, "fcs") and hasattr(self, "softpluses"):
            for fc, softplus in zip(self.fcs, self.softpluses):
                crys_fea = softplus(fc(crys_fea))
        out = self.fc_out(crys_fea)
        if self.classification:
            out = self.logsoftmax(out)
        return out, crys_fea

    def pooling(
        self, atom_fea: torch.Tensor, crystal_atom_idx: list[torch.LongTensor]
    ) -> torch.Tensor:
        """
        Aggregate atom features into crystal-level features by mean pooling.

        Args:
            atom_fea (torch.Tensor): Tensor of shape `(N, atom_fea_len)` containing the atom embeddings for all atoms in the batch.
            crystal_atom_idx (list[torch.LongTensor]): List of tensors, where `crystal_atom_idx[i]` contains the indices of atoms belonging to the `i`-th crystal. The concatenated indices must cover every atom exactly once.

        Returns:
            mean_fea (torch.Tensor): Tensor of shape `(n_crystals, atom_fea_len)` containing the mean-pooled crystal embeddings.
        """
        assert sum(len(idx) for idx in crystal_atom_idx) == atom_fea.shape[0]
        mean_fea = torch.stack(
            [torch.mean(atom_fea[idx], dim=0) for idx in crystal_atom_idx], dim=0
        )
        return mean_fea
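
An end-to-end sketch for a batch of two small crystals (4 and 3 atoms), using 92-dimensional atom features (the size of the standard `atom_init.json` embedding) and 41-dimensional bond features:

```python
import torch

from cgcnn2.model import CrystalGraphConvNet

model = CrystalGraphConvNet(orig_atom_fea_len=92, nbr_fea_len=41)

atom_fea = torch.randn(7, 92)                # 4 + 3 atoms in total
nbr_fea = torch.randn(7, 12, 41)
nbr_fea_idx = torch.randint(0, 7, (7, 12))
crystal_atom_idx = [torch.arange(0, 4), torch.arange(4, 7)]

out, crys_fea = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)
print(out.shape, crys_fea.shape)  # (2, 1) and (2, 128)
```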

__init__(orig_atom_fea_len, nbr_fea_len, atom_fea_len=64, n_conv=3, h_fea_len=128, n_h=1, classification=False)

Initialize CrystalGraphConvNet.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `orig_atom_fea_len` | `int` | Number of atom features in the input. | *required* |
| `nbr_fea_len` | `int` | Number of bond features. | *required* |
| `atom_fea_len` | `int` | Number of hidden atom features in the convolutional layers. | `64` |
| `n_conv` | `int` | Number of convolutional layers. | `3` |
| `h_fea_len` | `int` | Number of hidden features after pooling. | `128` |
| `n_h` | `int` | Number of hidden layers after pooling. | `1` |
| `classification` | `bool` | Whether to use classification or regression. | `False` |
Source code in cgcnn2/model.py
def __init__(
    self,
    orig_atom_fea_len: int,
    nbr_fea_len: int,
    atom_fea_len: int = 64,
    n_conv: int = 3,
    h_fea_len: int = 128,
    n_h: int = 1,
    classification: bool = False,
) -> None:
    """
    Initialize CrystalGraphConvNet.

    Args:
        orig_atom_fea_len (int): Number of atom features in the input.
        nbr_fea_len (int): Number of bond features.
        atom_fea_len (int): Number of hidden atom features in the convolutional layers
        n_conv (int): Number of convolutional layers
        h_fea_len (int): Number of hidden features after pooling
        n_h (int): Number of hidden layers after pooling
        classification (bool): Whether to use classification or regression
    """
    super().__init__()
    self.classification = classification
    self.embedding = nn.Linear(orig_atom_fea_len, atom_fea_len)
    self.convs = nn.ModuleList(
        [
            ConvLayer(atom_fea_len=atom_fea_len, nbr_fea_len=nbr_fea_len)
            for _ in range(n_conv)
        ]
    )
    self.conv_to_fc = nn.Linear(atom_fea_len, h_fea_len)
    self.conv_to_fc_softplus = nn.Softplus()
    if n_h > 1:
        self.fcs = nn.ModuleList(
            [nn.Linear(h_fea_len, h_fea_len) for _ in range(n_h - 1)]
        )
        self.softpluses = nn.ModuleList([nn.Softplus() for _ in range(n_h - 1)])

    if self.classification:
        self.fc_out = nn.Linear(h_fea_len, 2)
    else:
        self.fc_out = nn.Linear(h_fea_len, 1)

    if self.classification:
        self.logsoftmax = nn.LogSoftmax(dim=1)
        self.dropout = nn.Dropout()

forward(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

Forward pass.

`N`: Total number of atoms in the batch
`M`: Max number of neighbors
`N0`: Total number of crystals in the batch

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `atom_fea` | `torch.Tensor` | Tensor of shape `(N, orig_atom_fea_len)` containing the atom features from atom type. | *required* |
| `nbr_fea` | `torch.Tensor` | Tensor of shape `(N, M, nbr_fea_len)` containing the bond features of each atom's `M` neighbors. | *required* |
| `nbr_fea_idx` | `torch.LongTensor` | Tensor of shape `(N, M)` containing the indices of the `M` neighbors of each atom. | *required* |
| `crystal_atom_idx` | list of `torch.LongTensor` | Mapping from the crystal index to atom indices. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `out` | `torch.Tensor` | `(n_crystals, 2)` if `classification=True`, containing log-probabilities; `(n_crystals, 1)` if `classification=False`, containing the regression output. |
| `crys_fea` | `torch.Tensor` | Tensor of shape `(n_crystals, h_fea_len)` containing the pooled crystal embeddings. |

Source code in cgcnn2/model.py
def forward(
    self,
    atom_fea: torch.Tensor,
    nbr_fea: torch.Tensor,
    nbr_fea_idx: torch.LongTensor,
    crystal_atom_idx: list[torch.LongTensor],
):
    """
    Forward pass.

    `N`: Total number of atoms in the batch
    `M`: Max number of neighbors
    `N0`: Total number of crystals in the batch

    Args:
        atom_fea (torch.Tensor): Tensor of shape `(N, orig_atom_fea_len)` containing the atom features from atom type.
        nbr_fea (torch.Tensor): Tensor of shape `(N, M, nbr_fea_len)` containing the bond features of each atom's `M` neighbors.
        nbr_fea_idx (torch.LongTensor): Tensor of shape `(N, M)` containing the indices of the `M` neighbors of each atom.
        crystal_atom_idx (list of torch.LongTensor): Mapping from the crystal index to atom index.

    Returns:
        out (torch.Tensor):
            • `(n_crystals, 2)` if `classification=True`, containing log-probabilities.
            • `(n_crystals, 1)` if `classification=False`, containing the regression output.
        crys_fea (torch.Tensor): Tensor of shape `(n_crystals, h_fea_len)` containing the pooled crystal embeddings.

    """
    atom_fea = self.embedding(atom_fea)
    for conv_func in self.convs:
        atom_fea = conv_func(atom_fea, nbr_fea, nbr_fea_idx)
    crys_fea = self.pooling(atom_fea, crystal_atom_idx)
    crys_fea = self.conv_to_fc(self.conv_to_fc_softplus(crys_fea))
    crys_fea = self.conv_to_fc_softplus(crys_fea)
    if self.classification:
        crys_fea = self.dropout(crys_fea)
    if hasattr(self, "fcs") and hasattr(self, "softpluses"):
        for fc, softplus in zip(self.fcs, self.softpluses):
            crys_fea = softplus(fc(crys_fea))
    out = self.fc_out(crys_fea)
    if self.classification:
        out = self.logsoftmax(out)
    return out, crys_fea

pooling(atom_fea, crystal_atom_idx)

Aggregate atom features into crystal-level features by mean pooling.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `atom_fea` | `torch.Tensor` | Tensor of shape `(N, atom_fea_len)` containing the atom embeddings for all atoms in the batch. | *required* |
| `crystal_atom_idx` | `list[torch.LongTensor]` | List of tensors, where `crystal_atom_idx[i]` contains the indices of atoms belonging to the `i`-th crystal. The concatenated indices must cover every atom exactly once. | *required* |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `mean_fea` | `torch.Tensor` | Tensor of shape `(n_crystals, atom_fea_len)` containing the mean-pooled crystal embeddings. |

Source code in cgcnn2/model.py
def pooling(
    self, atom_fea: torch.Tensor, crystal_atom_idx: list[torch.LongTensor]
) -> torch.Tensor:
    """
    Aggregate atom features into crystal-level features by mean pooling.

    Args:
        atom_fea (torch.Tensor): Tensor of shape `(N, atom_fea_len)` containing the atom embeddings for all atoms in the batch.
        crystal_atom_idx (list[torch.LongTensor]): List of tensors, where `crystal_atom_idx[i]` contains the indices of atoms belonging to the `i`-th crystal. The concatenated indices must cover every atom exactly once.

    Returns:
        mean_fea (torch.Tensor): Tensor of shape `(n_crystals, atom_fea_len)` containing the mean-pooled crystal embeddings.
    """
    assert sum(len(idx) for idx in crystal_atom_idx) == atom_fea.shape[0]
    mean_fea = torch.stack(
        [torch.mean(atom_fea[idx], dim=0) for idx in crystal_atom_idx], dim=0
    )
    return mean_fea

cgcnn2.utils

Normalizer

Normalizes a PyTorch tensor and allows restoring it later.

This class keeps track of the mean and standard deviation of a tensor and provides methods to normalize and denormalize tensors using these statistics.

Source code in cgcnn2/utils.py
class Normalizer:
    """
    Normalizes a PyTorch tensor and allows restoring it later.

    This class keeps track of the mean and standard deviation of a tensor and provides methods
    to normalize and denormalize tensors using these statistics.
    """

    def __init__(self, tensor: torch.Tensor) -> None:
        """
        Initialize the Normalizer with a sample tensor to calculate mean and standard deviation.

        Args:
            tensor (torch.Tensor): Sample tensor to compute mean and standard deviation.
        """
        self.mean: torch.Tensor = torch.mean(tensor)
        self.std: torch.Tensor = torch.std(tensor)

    def norm(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize a tensor using the stored mean and standard deviation.

        Args:
            tensor (torch.Tensor): Tensor to normalize.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        return (tensor - self.mean) / self.std

    def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor:
        """
        Denormalize a tensor using the stored mean and standard deviation.

        Args:
            normed_tensor (torch.Tensor): Normalized tensor to denormalize.

        Returns:
            torch.Tensor: Denormalized tensor.
        """
        return normed_tensor * self.std + self.mean

    def state_dict(self) -> dict[str, torch.Tensor]:
        """
        Returns the state dictionary containing the mean and standard deviation.

        Returns:
            dict[str, torch.Tensor]: State dictionary.
        """
        return {"mean": self.mean, "std": self.std}

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
        """
        Loads the mean and standard deviation from a state dictionary.

        Args:
            state_dict (dict[str, torch.Tensor]): State dictionary containing 'mean' and 'std'.
        """
        self.mean = state_dict["mean"]
        self.std = state_dict["std"]
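
A round-trip sketch: fit the statistics on training targets, normalize for training, then denormalize model outputs:

```python
import torch

from cgcnn2.utils import Normalizer

targets = torch.tensor([1.2, 0.7, 2.4, 1.9])
normalizer = Normalizer(targets)

normed = normalizer.norm(targets)
restored = normalizer.denorm(normed)
assert torch.allclose(restored, targets)
```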

__init__(tensor)

Initialize the Normalizer with a sample tensor to calculate mean and standard deviation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tensor` | `torch.Tensor` | Sample tensor to compute mean and standard deviation. | *required* |
Source code in cgcnn2/utils.py
def __init__(self, tensor: torch.Tensor) -> None:
    """
    Initialize the Normalizer with a sample tensor to calculate mean and standard deviation.

    Args:
        tensor (torch.Tensor): Sample tensor to compute mean and standard deviation.
    """
    self.mean: torch.Tensor = torch.mean(tensor)
    self.std: torch.Tensor = torch.std(tensor)

denorm(normed_tensor)

Denormalize a tensor using the stored mean and standard deviation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `normed_tensor` | `torch.Tensor` | Normalized tensor to denormalize. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `torch.Tensor` | Denormalized tensor. |

Source code in cgcnn2/utils.py
def denorm(self, normed_tensor: torch.Tensor) -> torch.Tensor:
    """
    Denormalize a tensor using the stored mean and standard deviation.

    Args:
        normed_tensor (torch.Tensor): Normalized tensor to denormalize.

    Returns:
        torch.Tensor: Denormalized tensor.
    """
    return normed_tensor * self.std + self.mean

load_state_dict(state_dict)

Loads the mean and standard deviation from a state dictionary.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state_dict` | `dict[str, torch.Tensor]` | State dictionary containing 'mean' and 'std'. | *required* |
required
Source code in cgcnn2/utils.py
def load_state_dict(self, state_dict: dict[str, torch.Tensor]) -> None:
    """
    Loads the mean and standard deviation from a state dictionary.

    Args:
        state_dict (dict[str, torch.Tensor]): State dictionary containing 'mean' and 'std'.
    """
    self.mean = state_dict["mean"]
    self.std = state_dict["std"]

norm(tensor)

Normalize a tensor using the stored mean and standard deviation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `tensor` | `Tensor` | Tensor to normalize. | required |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | Normalized tensor. |


Source code in cgcnn2/utils.py
def norm(self, tensor: torch.Tensor) -> torch.Tensor:
    """
    Normalize a tensor using the stored mean and standard deviation.

    Args:
        tensor (torch.Tensor): Tensor to normalize.

    Returns:
        torch.Tensor: Normalized tensor.
    """
    return (tensor - self.mean) / self.std

state_dict()

Returns the state dictionary containing the mean and standard deviation.

Returns:

| Type | Description |
| --- | --- |
| `dict[str, Tensor]` | State dictionary. |


Source code in cgcnn2/utils.py
def state_dict(self) -> dict[str, torch.Tensor]:
    """
    Returns the state dictionary containing the mean and standard deviation.

    Returns:
        dict[str, torch.Tensor]: State dictionary.
    """
    return {"mean": self.mean, "std": self.std}

cgcnn_descriptor(model, loader, device, verbose)

This function takes a pre-trained CGCNN model and a dataset, runs inference to generate predictions and features from the last layer, and returns both. Target values are not required for the prediction set.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The trained CGCNN model. | required |
| `loader` | `DataLoader` | DataLoader for the dataset. | required |
| `device` | `str` | The device ('cuda' or 'cpu') where the model will be run. | required |
| `verbose` | `int` | The verbosity level of the output. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `tuple` | `tuple[list[float], list[Tensor]]` | A tuple containing: model predictions (`list`) and crystal features from the last layer (`list`). |


Notes

This function is intended for use in programmatic downstream analysis, where the user wants to continue downstream analysis using predictions or features (descriptors) generated by the model. For the command-line interface, consider using the cgcnn_pr script instead.

Source code in cgcnn2/utils.py
def cgcnn_descriptor(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: str,
    verbose: int,
) -> tuple[list[float], list[torch.Tensor]]:
    """
    This function takes a pre-trained CGCNN model and a dataset, runs inference
    to generate predictions and features from the last layer, and returns both.
    Target values are not required for the prediction set.

    Args:
        model (torch.nn.Module): The trained CGCNN model.
        loader (torch.utils.data.DataLoader): DataLoader for the dataset.
        device (str): The device ('cuda' or 'cpu') where the model will be run.
        verbose (int): The verbosity level of the output.

    Returns:
        tuple: A tuple containing:
            - list: Model predictions
            - list: Crystal features from the last layer

    Notes:
        This function is intended for use in programmatic downstream analysis,
        where the user wants to continue downstream analysis using predictions or
        features (descriptors) generated by the model. For the command-line interface,
        consider using the cgcnn_pr script instead.
    """

    model.eval()
    targets_list = []
    outputs_list = []
    crys_feas_list = []
    index = 0

    with torch.inference_mode():
        for input, target, cif_id in loader:
            atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = input
            atom_fea = atom_fea.to(device)
            nbr_fea = nbr_fea.to(device)
            nbr_fea_idx = nbr_fea_idx.to(device)
            crystal_atom_idx = [idx_map.to(device) for idx_map in crystal_atom_idx]
            target = target.to(device)

            output, crys_fea = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

            targets_list.extend(target.cpu().numpy().ravel().tolist())
            outputs_list.extend(output.cpu().numpy().ravel().tolist())
            crys_feas_list.append(crys_fea.cpu().numpy())

            index += 1

            # Extract a scalar cif id and prediction value for logging
            cif_id_value = cif_id[0] if cif_id and isinstance(cif_id, list) else cif_id
            prediction_value = output.item() if output.numel() == 1 else output.tolist()

            if verbose >= 10:
                logging.info(
                    f"index: {index} | cif id: {cif_id_value} | prediction: {prediction_value}"
                )

    return outputs_list, crys_feas_list

cgcnn_pred(model_path, full_set, verbose=101, cuda=False, num_workers=0)

This function takes the path to a pre-trained CGCNN model and a dataset, runs inference to generate predictions, and returns them. Target values are not required for the prediction set.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_path` | `str` | Path to the file containing the pre-trained model parameters. | required |
| `full_set` | `str` | Path to the directory containing all CIF files for the dataset. | required |
| `verbose` | `int` | Verbosity level of the output. | `101` |
| `cuda` | `bool` | Whether to use CUDA. | `False` |
| `num_workers` | `int` | Number of subprocesses for data loading. | `0` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `tuple` | `tuple[list[float], list[Tensor]]` | A tuple containing: model predictions (`list`) and features from the last layer (`list`). |


Notes

This function is intended for use in programmatic downstream analysis, where the user wants to continue downstream analysis using predictions or features (descriptors) generated by the model. For the command-line interface, consider using the cgcnn_pr script instead.

Source code in cgcnn2/utils.py
def cgcnn_pred(
    model_path: str,
    full_set: str,
    verbose: int = 101,
    cuda: bool = False,
    num_workers: int = 0,
) -> tuple[list[float], list[torch.Tensor]]:
    """
    This function takes the path to a pre-trained CGCNN model and a dataset,
    runs inference to generate predictions, and returns them. Target values
    are not required for the prediction set.

    Args:
        model_path (str): Path to the file containing the pre-trained model parameters.
        full_set (str): Path to the directory containing all CIF files for the dataset.
        verbose (int): Verbosity level of the output.
        cuda (bool): Whether to use CUDA.
        num_workers (int): Number of subprocesses for data loading.

    Returns:
        tuple: A tuple containing:
            - list: Model predictions
            - list: Features from the last layer

    Notes:
        This function is intended for use in programmatic downstream analysis,
        where the user wants to continue downstream analysis using predictions or
        features (descriptors) generated by the model. For the command-line interface,
        consider using the cgcnn_pr script instead.
    """
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"=> No model params found at '{model_path}'")

    total_dataset = CIFData_NoTarget(full_set)

    checkpoint = torch.load(
        model_path,
        map_location=lambda storage, loc: storage if not cuda else None,
        weights_only=False,
    )
    structures, _, _ = total_dataset[0]
    orig_atom_fea_len = structures[0].shape[-1]
    nbr_fea_len = structures[1].shape[-1]
    model_args = argparse.Namespace(**checkpoint["args"])
    model = CrystalGraphConvNet(
        orig_atom_fea_len,
        nbr_fea_len,
        atom_fea_len=model_args.atom_fea_len,
        n_conv=model_args.n_conv,
        h_fea_len=model_args.h_fea_len,
        n_h=model_args.n_h,
    )
    if cuda:
        model.cuda()

    normalizer = Normalizer(torch.zeros(3))
    normalizer.load_state_dict(checkpoint["normalizer"])
    model.load_state_dict(checkpoint["state_dict"])

    if verbose >= 100:
        print_checkpoint_info(checkpoint, model_path)

    device = "cuda" if cuda else "cpu"
    model.to(device).eval()

    full_loader = DataLoader(
        total_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_pool,
        pin_memory=cuda,
    )

    pred, last_layer = cgcnn_descriptor(model, full_loader, device, verbose)

    return pred, last_layer
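
Example (a minimal sketch; the checkpoint and CIF paths are hypothetical placeholders):

```python
from cgcnn2.utils import cgcnn_pred

# Hypothetical paths: a checkpoint saved during training and a directory
# of CIF files to predict on. No target values are needed.
predictions, features = cgcnn_pred(
    model_path="./model_best.pth.tar",
    full_set="./cif_dir",
    verbose=0,
    cuda=False,
)

print(predictions[:5])  # one scalar prediction per structure
# `features` holds the last-layer crystal descriptors, one array per structure.
```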

cgcnn_test(model, loader, device, results_file='results.csv', plot_file='parity_plot.png', axis_limits=None, **kwargs)

This function takes a pre-trained CGCNN model and a test dataset, runs inference to generate predictions, creates a parity plot comparing predicted versus true values, and writes the results to a CSV file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | The pre-trained CGCNN model. | required |
| `loader` | `DataLoader` | DataLoader for the dataset. | required |
| `device` | `str` | The device ('cuda' or 'cpu') where the model will be run. | required |
| `results_file` | `str` | File path for saving results as CSV. | `'results.csv'` |
| `plot_file` | `str` | File path for saving the parity plot. | `'parity_plot.png'` |
| `axis_limits` | `list` | Limits for x-axis (true values) of the parity plot. | `None` |
| `**kwargs` | `Any` | Additional keyword arguments: `xlabel` (str), x-axis label for the parity plot; `ylabel` (str), y-axis label for the parity plot. | `{}` |

Notes

This function is intended for use in a command-line interface, providing direct output of results. For programmatic downstream analysis, consider using the API functions instead, i.e. cgcnn_pred and cgcnn_descriptor.

Source code in cgcnn2/utils.py
def cgcnn_test(
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    device: str,
    results_file: str = "results.csv",
    plot_file: str = "parity_plot.png",
    axis_limits: list[float] | None = None,
    **kwargs: Any,
) -> None:
    """
    This function takes a pre-trained CGCNN model and a test dataset, runs
    inference to generate predictions, creates a parity plot comparing
    predicted versus true values, and writes the results to a CSV file.

    Args:
        model (torch.nn.Module): The pre-trained CGCNN model.
        loader (torch.utils.data.DataLoader): DataLoader for the dataset.
        device (str): The device ('cuda' or 'cpu') where the model will be run.
        results_file (str, optional): File path for saving results as CSV.
        plot_file (str, optional): File path for saving the parity plot.
        axis_limits (list, optional): Limits for x-axis (true values) of the parity plot.
        **kwargs: Additional keyword arguments:
            xlabel (str): x-axis label for the parity plot.
            ylabel (str): y-axis label for the parity plot.

    Notes:
        This function is intended for use in a command-line interface, providing
        direct output of results. For programmatic downstream analysis, consider
        using the API functions instead, i.e. cgcnn_pred and cgcnn_descriptor.
    """

    # Extract optional plot labels from kwargs
    xlabel = kwargs.get("xlabel", "true")
    ylabel = kwargs.get("ylabel", "pred")

    model.eval()
    targets_list = []
    outputs_list = []
    cif_ids = []

    with torch.inference_mode():
        for input_batch, target, cif_id in loader:
            atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx = input_batch
            atom_fea = atom_fea.to(device)
            nbr_fea = nbr_fea.to(device)
            nbr_fea_idx = nbr_fea_idx.to(device)
            crystal_atom_idx = [idx_map.to(device) for idx_map in crystal_atom_idx]
            target = target.to(device)
            output, _ = model(atom_fea, nbr_fea, nbr_fea_idx, crystal_atom_idx)

            targets_list.extend(target.cpu().numpy().ravel().tolist())
            outputs_list.extend(output.cpu().numpy().ravel().tolist())
            cif_ids.extend(cif_id)

    targets_array = np.array(targets_list)
    outputs_array = np.array(outputs_list)

    # MSE and R2 Score
    mse = np.mean((targets_array - outputs_array) ** 2)
    ss_res = np.sum((targets_array - outputs_array) ** 2)
    ss_tot = np.sum((targets_array - np.mean(targets_array)) ** 2)
    r2 = 1 - ss_res / ss_tot
    logging.info(f"MSE: {mse:.4f}, R2 Score: {r2:.4f}")

    # Save results to CSV
    sorted_rows = sorted(zip(cif_ids, targets_list, outputs_list), key=lambda x: x[0])
    with open(results_file, "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["cif_id", "true", "pred"])
        writer.writerows(sorted_rows)
    logging.info(f"Prediction results have been saved to {results_file}")

    # Create parity plot
    df_full = pd.DataFrame({"true": targets_list, "pred": outputs_list})
    plot_hexbin(df_full, xlabel, ylabel, out_png=plot_file)
    logging.info(f"Hexbin plot has been saved to {plot_file}")

    # If axis limits are provided, save an additional parity plot clipped to those limits
    if axis_limits:
        df_clip = df_full[
            (df_full["true"] >= axis_limits[0]) & (df_full["true"] <= axis_limits[1])
        ]
        clipped_file = plot_file.replace(
            ".png", f"_axis_limits_{axis_limits[0]}_{axis_limits[1]}.png"
        )
        plot_hexbin(df_clip, xlabel, ylabel, out_png=clipped_file)
        logging.info(
            f"Hexbin plot clipped to {axis_limits} on true has been saved to {clipped_file}"
        )
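
Example (a sketch that restores the model from a checkpoint, following the same steps cgcnn_pred uses above; the paths are hypothetical, and the import locations of CIFData and CrystalGraphConvNet are assumptions):

```python
import argparse

import torch
from torch.utils.data import DataLoader

from cgcnn2.data import CIFData, collate_pool  # assumed import locations
from cgcnn2.model import CrystalGraphConvNet   # assumed import location
from cgcnn2.utils import cgcnn_test

device = "cpu"
dataset = CIFData("./test_cifs")  # hypothetical test-set directory with targets

# Rebuild the model from a checkpoint, mirroring cgcnn_pred above.
checkpoint = torch.load("./model_best.pth.tar", map_location=device, weights_only=False)
structures, _, _ = dataset[0]
model_args = argparse.Namespace(**checkpoint["args"])
model = CrystalGraphConvNet(
    structures[0].shape[-1],  # orig_atom_fea_len
    structures[1].shape[-1],  # nbr_fea_len
    atom_fea_len=model_args.atom_fea_len,
    n_conv=model_args.n_conv,
    h_fea_len=model_args.h_fea_len,
    n_h=model_args.n_h,
)
model.load_state_dict(checkpoint["state_dict"])
model.to(device)

loader = DataLoader(dataset, batch_size=32, shuffle=False, collate_fn=collate_pool)
cgcnn_test(
    model,
    loader,
    device,
    results_file="results.csv",
    plot_file="parity_plot.png",
    xlabel="true (eV)",
    ylabel="pred (eV)",
)
```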

get_local_version()

Retrieves the version of the project from the pyproject.toml file.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `version` | `str` | The version of the project. |


Source code in cgcnn2/utils.py
def get_local_version() -> str:
    """
    Retrieves the version of the project from the pyproject.toml file.

    Returns:
        version (str): The version of the project.
    """
    project_root = Path(__file__).parents[2]
    toml_path = project_root / "pyproject.toml"
    try:
        with toml_path.open("rb") as f:
            data = tomllib.load(f)
            version = data["project"]["version"]
        return version
    except Exception:
        return "unknown"

get_lr(optimizer)

Extracts learning rates from a PyTorch optimizer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `optimizer` | `Optimizer` | The PyTorch optimizer to extract learning rates from. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `learning_rates` | `list[float]` | A list of learning rates for each parameter group. |


Source code in cgcnn2/utils.py
def get_lr(optimizer: torch.optim.Optimizer) -> list[float]:
    """
    Extracts learning rates from a PyTorch optimizer.

    Args:
        optimizer (torch.optim.Optimizer): The PyTorch optimizer to extract learning rates from.

    Returns:
        learning_rates (list[float]): A list of learning rates for each parameter group.
    """

    learning_rates = []
    for param_group in optimizer.param_groups:
        learning_rates.append(param_group["lr"])
    return learning_rates
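
Example (a toy optimizer with two parameter groups):

```python
import torch

from cgcnn2.utils import get_lr

w1 = torch.nn.Parameter(torch.zeros(1))
w2 = torch.nn.Parameter(torch.zeros(1))

# The first group inherits the default lr; the second overrides it.
optimizer = torch.optim.SGD(
    [{"params": [w1]}, {"params": [w2], "lr": 1e-3}],
    lr=1e-2,
)

print(get_lr(optimizer))  # [0.01, 0.001]
```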

id_prop_gen(cif_dir)

Generates a CSV file containing IDs and properties of CIF files.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cif_dir` | `str` | Directory containing the CIF files. | required |

Source code in cgcnn2/utils.py
def id_prop_gen(cif_dir: str) -> None:
    """Generates a CSV file containing IDs and properties of CIF files.

    Args:
        cif_dir (str): Directory containing the CIF files.
    """

    cif_list = glob.glob(f"{cif_dir}/*.cif")

    id_prop_cif = pd.DataFrame(
        {
            "id": [os.path.basename(cif).split(".")[0] for cif in cif_list],
            "prop": [0 for _ in range(len(cif_list))],
        }
    )

    id_prop_cif.to_csv(
        f"{cif_dir}/id_prop.csv",
        index=False,
        header=False,
    )
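
Example (the directory below is a hypothetical placeholder; every structure receives a placeholder property of 0, which is enough for target-free prediction workflows):

```python
from cgcnn2.utils import id_prop_gen

# Hypothetical directory containing foo.cif, bar.cif, ...
id_prop_gen("./cif_dir")
# Writes ./cif_dir/id_prop.csv (no header) with rows like: foo,0
```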

metrics_text(df, metrics=['mae', 'r2'], metrics_precision='3f', unit=None, unit_scale=1.0)

Create a text string containing the metrics and their values.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | `DataFrame` | DataFrame containing the true and pred values. | required |
| `metrics` | `list[str]` | A list of metrics to be displayed in the plot. | `['mae', 'r2']` |
| `metrics_precision` | `str` | Format string for the metrics. | `'3f'` |
| `unit` | `str \| None` | Unit of the property. | `None` |
| `unit_scale` | `float` | Scale factor for the unit. | `1.0` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `text` | `str` | A text string containing the metrics and their values. |


Source code in cgcnn2/utils.py
def metrics_text(
    df: pd.DataFrame,
    metrics: list[str] = ["mae", "r2"],
    metrics_precision: str = "3f",
    unit: str | None = None,
    unit_scale: float = 1.0,
) -> str:
    """
    Create a text string containing the metrics and their values.

    Args:
        df (pd.DataFrame): DataFrame containing the true and pred values.
        metrics (list[str]): A list of metrics to be displayed in the plot.
        metrics_precision (str): Format string for the metrics.
        unit (str | None): Unit of the property.
        unit_scale (float): Scale factor for the unit.
    Returns:
        text (str): A text string containing the metrics and their values.
    """

    values: dict[str, float] = {}
    for m in metrics:
        m_lower = m.lower()
        if m_lower == "mae":
            values["MAE"] = np.mean(np.abs(df["true"] - df["pred"])) * unit_scale
        elif m_lower == "mse":
            values["MSE"] = np.mean((df["true"] - df["pred"]) ** 2) * unit_scale
        elif m_lower == "rmse":
            values["RMSE"] = (
                np.sqrt(np.mean((df["true"] - df["pred"]) ** 2)) * unit_scale
            )
        elif m_lower == "r2":
            values["R^2"] = 1 - np.sum((df["true"] - df["pred"]) ** 2) / np.sum(
                (df["true"] - df["true"].mean()) ** 2
            )
        else:
            raise ValueError(f"Unsupported metric: {m}")

    text_lines: list[str] = []
    for name, val in values.items():
        if unit and name == "MSE":
            unit_str = rf"\,\mathrm{{{unit}}}^2"
        elif unit and name != "R^2":
            unit_str = rf"\,\mathrm{{{unit}}}"
        else:
            unit_str = ""

        if name == "R^2":
            latex_name = r"R^2"
        else:
            latex_name = rf"\mathrm{{{name}}}"

        if name == "R^2":
            text_lines.append(rf"${latex_name}: {val:.3f}{unit_str}$")
        else:
            text_lines.append(rf"${latex_name}: {val:.{metrics_precision}}{unit_str}$")
    text = "\n".join(text_lines)

    return text
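
Example (toy data; the DataFrame must use the column names `true` and `pred`, as in the plotting helpers below):

```python
import pandas as pd

from cgcnn2.utils import metrics_text

df = pd.DataFrame({"true": [1.0, 2.0, 3.0], "pred": [1.1, 1.9, 3.2]})
print(metrics_text(df, metrics=["mae", "rmse", "r2"], unit="eV"))
# One LaTeX-formatted line per metric, e.g. $\mathrm{MAE}: 0.133\,\mathrm{eV}$
```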

output_id_gen()

Generates a unique output identifier based on current date and time.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `folder_name` | `str` | A string in format 'output_mmdd_HHMM' for current date/time. |


Source code in cgcnn2/utils.py
def output_id_gen() -> str:
    """
    Generates a unique output identifier based on current date and time.

    Returns:
        folder_name (str): A string in format 'output_mmdd_HHMM' for current date/time.
    """

    now = datetime.now()
    timestamp = now.strftime("%m%d_%H%M")
    folder_name = f"output_{timestamp}"

    return folder_name

plot_convergence(df, xlabel, ylabel, ax=None, y2label=None, ylabel_precision='3f', y2label_precision='3f', colors=('#137DC5', '#BF1922'), xtick_rotation=0, subfigure_label=None, out_png=None)

Create a convergence plot and save it to a file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | `DataFrame` | DataFrame containing the metrics values. | required |
| `xlabel` | `str` | Label for the x-axis (epochs). | required |
| `ylabel` | `str` | Label for the y-axis (metric). | required |
| `ax` | `Axes \| None` | Axes object to plot the convergence on. | `None` |
| `y2label` | `str \| None` | Label for the y2-axis (metric). | `None` |
| `ylabel_precision` | `str` | Format string for the y-axis label. | `'3f'` |
| `y2label_precision` | `str` | Format string for the y2-axis label. | `'3f'` |
| `colors` | `Sequence[str]` | Colors for the lines. | `('#137DC5', '#BF1922')` |
| `xtick_rotation` | `float` | Rotation of the x-axis tick labels. | `0` |
| `subfigure_label` | `str \| None` | Label for the subfigure. | `None` |
| `out_png` | `str \| None` | Path of the PNG file in which to save the convergence plot. | `None` |

Source code in cgcnn2/utils.py
def plot_convergence(
    df: pd.DataFrame,
    xlabel: str,
    ylabel: str,
    ax: plt.Axes | None = None,
    y2label: str | None = None,
    ylabel_precision: str = "3f",
    y2label_precision: str = "3f",
    colors: Sequence[str] = ("#137DC5", "#BF1922"),
    xtick_rotation: float = 0,
    subfigure_label: str | None = None,
    out_png: str | None = None,
) -> None:
    """
    Create a convergence plot and save it to a file.

    Args:
        df (pd.DataFrame): DataFrame containing the metrics values.
        xlabel (str): Label for the x-axis (epochs)
        ylabel (str): Label for the y-axis (metric)
        ax (plt.Axes | None): Axes object to plot the convergence on.
        y2label (str | None): Label for the y2-axis (metric)
        ylabel_precision (str): Format string for the y-axis label.
        y2label_precision (str): Format string for the y2-axis label.
        colors (Sequence[str]): Colors for the lines.
        xtick_rotation (float): Rotation of the x-axis tick labels.
        subfigure_label (str | None): Label for the subfigure.
        out_png (str | None): Path of the PNG file in which to save the convergence plot.

    """

    with plt.rc_context(PLOT_RC_PARAMS):
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 6), layout="constrained")
        else:
            fig = ax.get_figure()

        x = df[xlabel]
        y = df[ylabel]

        # Primary line (left y‑axis)
        (ln1,) = ax.plot(x, y, label=ylabel, color=colors[0])

        lines = [ln1]
        labels = [ylabel]

        # Optional secondary line (right y‑axis)
        if y2label is not None:
            y2 = df[y2label]
            ax2 = ax.twinx()

            (ln2,) = ax2.plot(x, y2, linestyle="--", label=y2label, color=colors[1])

            lines.append(ln2)
            labels.append(y2label)

            y1_lim = ax.get_ylim()
            y2_lim = ax2.get_ylim()

            ax.set_yticks(np.linspace(y1_lim[0], y1_lim[1], 6))
            ax2.set_yticks(np.linspace(y2_lim[0], y2_lim[1], 6))

            ax.yaxis.set_major_formatter(
                mticker.FormatStrFormatter(f"%.{ylabel_precision}")
            )
            ax2.yaxis.set_major_formatter(
                mticker.FormatStrFormatter(f"%.{y2label_precision}")
            )

            ax.yaxis.set_minor_locator(mticker.AutoMinorLocator(2))
            ax2.yaxis.set_minor_locator(mticker.AutoMinorLocator(2))

            ax.legend(lines, labels, loc="center right")

            ax.spines["left"].set_color(colors[0])
            ax2.spines["left"].set_visible(False)
            ax2.spines["right"].set_color(colors[1])
            ax.spines["right"].set_visible(False)

            ax.tick_params(axis="y", colors=colors[0], which="both")
            ax2.tick_params(axis="y", colors=colors[1], which="both")

        else:
            ax.set_xlabel(xlabel)
            ax.set_ylabel(ylabel)

        ax.tick_params(axis="x", rotation=xtick_rotation)

        ax.grid(True, which="major", alpha=0.3)

        if subfigure_label is not None:
            ax.text(
                0.025,
                0.975,
                subfigure_label,
                transform=ax.transAxes,
                ha="left",
                va="top",
            )

        if out_png is not None:
            fig.savefig(out_png, dpi=300, bbox_inches="tight")
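
Example (a synthetic training log; the column names only have to match the `xlabel`/`ylabel`/`y2label` arguments):

```python
import pandas as pd

from cgcnn2.utils import plot_convergence

# Hypothetical per-epoch log with a training loss and a validation MAE.
epochs = list(range(1, 51))
df = pd.DataFrame(
    {
        "epoch": epochs,
        "train_loss": [1.0 / e for e in epochs],
        "valid_mae": [0.5 / e + 0.05 for e in epochs],
    }
)

plot_convergence(
    df,
    xlabel="epoch",
    ylabel="train_loss",
    y2label="valid_mae",
    out_png="convergence.png",
)
```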

plot_hexbin(df, xlabel, ylabel, ax=None, metrics=['mae', 'r2'], metrics_precision='3f', unit=None, unit_scale=1.0, subfigure_label=None, out_png=None)

Create a hexbin plot and save it to a file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | `DataFrame` | DataFrame containing the true and pred values. | required |
| `xlabel` | `str` | Label for the x-axis. | required |
| `ylabel` | `str` | Label for the y-axis. | required |
| `ax` | `Axes \| None` | Axes object to plot the hexbin on. | `None` |
| `metrics` | `list[str]` | A list of strings to be displayed in the plot. | `['mae', 'r2']` |
| `metrics_precision` | `str` | Format string for the metrics. | `'3f'` |
| `unit` | `str \| None` | Unit of the property. | `None` |
| `unit_scale` | `float` | Scale factor for the unit. | `1.0` |
| `subfigure_label` | `str \| None` | Label for the subfigure. | `None` |
| `out_png` | `str \| None` | Path of the PNG file in which to save the hexbin plot. | `None` |

Source code in cgcnn2/utils.py
def plot_hexbin(
    df: pd.DataFrame,
    xlabel: str,
    ylabel: str,
    ax: plt.Axes | None = None,
    metrics: list[str] = ["mae", "r2"],
    metrics_precision: str = "3f",
    unit: str | None = None,
    unit_scale: float = 1.0,
    subfigure_label: str | None = None,
    out_png: str | None = None,
) -> None:
    """
    Create a hexbin plot and save it to a file.

    Args:
        df (pd.DataFrame): DataFrame containing the true and pred values.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        ax (plt.Axes | None): Axes object to plot the hexbin on.
        metrics (list[str]): A list of strings to be displayed in the plot.
        metrics_precision (str): Format string for the metrics.
        unit (str | None): Unit of the property.
        unit_scale (float): Scale factor for the unit.
        subfigure_label (str | None): Label for the subfigure.
        out_png (str | None): Path of the PNG file in which to save the hexbin plot.

    """

    with plt.rc_context(PLOT_RC_PARAMS):
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 6), layout="constrained")
        else:
            ax.get_figure()

        hb = ax.hexbin(
            x="true",
            y="pred",
            data=df,
            gridsize=40,
            cmap="viridis",
            bins="log",
        )

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        # Keep axes square
        ax.set_box_aspect(1)

        # Get the current axis limits
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        min_val = min(xlim[0], ylim[0])
        max_val = max(xlim[1], ylim[1])

        # Plot y = x reference line (grey dashed)
        ax.plot(
            [min_val, max_val],
            [min_val, max_val],
            linestyle="--",
            color="grey",
            linewidth=2,
        )

        # Restore the original limits
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # add density colorbar put inside the plot
        cax = inset_axes(
            ax, width="3.5%", height="70%", loc="lower right", borderpad=0.5
        )
        plt.colorbar(hb, cax=cax)
        cax.yaxis.set_ticks_position("left")
        cax.yaxis.set_label_position("left")

        # Compute requested metrics
        text = metrics_text(df, metrics, metrics_precision, unit, unit_scale)

        if subfigure_label is not None:
            text = f"{subfigure_label}\n{text}"

        ax.text(
            0.025,
            0.975,
            text,
            transform=ax.transAxes,
            ha="left",
            va="top",
        )

        if out_png is not None:
            plt.savefig(out_png, format="png", dpi=300, bbox_inches="tight")
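
Example (synthetic parity data; note that the hexbin itself is always drawn from the `true` and `pred` columns, while `xlabel`/`ylabel` only set the axis text):

```python
import numpy as np
import pandas as pd

from cgcnn2.utils import plot_hexbin

# Synthetic parity data: predictions scattered around the true values.
rng = np.random.default_rng(0)
true = rng.normal(size=2000)
df = pd.DataFrame({"true": true, "pred": true + rng.normal(scale=0.2, size=2000)})

plot_hexbin(df, xlabel="true (eV)", ylabel="pred (eV)", unit="eV", out_png="hexbin.png")
```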

plot_scatter(df, xlabel, ylabel, ax=None, true_types=['true_train', 'true_valid', 'true_test'], pred_types=['pred_train', 'pred_valid', 'pred_test'], colors=('#137DC5', '#FACF39', '#BF1922', '#F7E8D3', '#B89FDC', '#0F0C08'), legend_labels=None, metrics=['mae', 'r2'], metrics_precision='3f', unit=None, unit_scale=1.0, subfigure_label=None, out_png=None)

Create a scatter plot and save it to a file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df` | `DataFrame` | DataFrame containing the true and pred values. | required |
| `xlabel` | `str` | Label for the x-axis. | required |
| `ylabel` | `str` | Label for the y-axis. | required |
| `ax` | `Axes \| None` | Axes object to plot the scatter on. | `None` |
| `true_types` | `list[str]` | A list of true data types to be displayed in the plot. | `['true_train', 'true_valid', 'true_test']` |
| `pred_types` | `list[str]` | A list of pred data types to be displayed in the plot. | `['pred_train', 'pred_valid', 'pred_test']` |
| `colors` | `Sequence[str]` | A list of colors to be used for the data types. Default palette is adapted from [Looka 2025](https://looka.com/blog/logo-color-trends/) with six colors. | `('#137DC5', '#FACF39', '#BF1922', '#F7E8D3', '#B89FDC', '#0F0C08')` |
| `legend_labels` | `list[str] \| None` | A list of labels for the legend. | `None` |
| `metrics` | `list[str]` | Metrics to display in the plot. | `['mae', 'r2']` |
| `metrics_precision` | `str` | Format string for the metrics. | `'3f'` |
| `unit` | `str \| None` | Unit of the property. | `None` |
| `unit_scale` | `float` | Scale factor for the unit. | `1.0` |
| `subfigure_label` | `str \| None` | Label for the subfigure. | `None` |
| `out_png` | `str \| None` | Path of the PNG file in which to save the scatter plot. | `None` |

Source code in cgcnn2/utils.py
def plot_scatter(
    df: pd.DataFrame,
    xlabel: str,
    ylabel: str,
    ax: plt.Axes | None = None,
    true_types: list[str] = ["true_train", "true_valid", "true_test"],
    pred_types: list[str] = ["pred_train", "pred_valid", "pred_test"],
    colors: Sequence[str] = (
        "#137DC5",
        "#FACF39",
        "#BF1922",
        "#F7E8D3",
        "#B89FDC",
        "#0F0C08",
    ),
    legend_labels: list[str] | None = None,
    metrics: list[str] = ["mae", "r2"],
    metrics_precision: str = "3f",
    unit: str | None = None,
    unit_scale: float = 1.0,
    subfigure_label: str | None = None,
    out_png: str | None = None,
) -> None:
    """
    Create a scatter plot and save it to a file.

    Args:
        df (pd.DataFrame): DataFrame containing the true and pred values.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
        ax (plt.Axes | None): Axes object to plot the scatter on.
        true_types (list[str]): A list of true data types to be displayed in the plot.
        pred_types (list[str]): A list of pred data types to be displayed in the plot.
        colors (Sequence[str]): A list of colors to be used for the data types.
            Default palette is adapted from
            [Looka 2025](https://looka.com/blog/logo-color-trends/) with six colors.
        legend_labels (list[str] | None): A list of labels for the legend.
        metrics (list[str]): Metrics to display in the plot.
        metrics_precision (str): Format string for the metrics.
        unit (str | None): Unit of the property.
        unit_scale (float): Scale factor for the unit.
        subfigure_label (str | None): Label for the subfigure.
        out_png (str | None): Path of the PNG file in which to save the scatter plot.

    """

    with plt.rc_context(PLOT_RC_PARAMS):
        if ax is None:
            fig, ax = plt.subplots(figsize=(8, 6), layout="constrained")
        else:
            ax.get_figure()

        for true_type, pred_type in zip(true_types, pred_types):
            ax.scatter(
                x=true_type,
                y=pred_type,
                data=df,
                c=colors[true_types.index(true_type) % len(colors)],
                alpha=0.5,
            )

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        # Keep axes square
        ax.set_box_aspect(1)

        # Get the current axis limits
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        axis_min = min(xlim[0], ylim[0])
        axis_max = max(xlim[1], ylim[1])

        # Plot y = x reference line (grey dashed)
        ax.plot(
            [axis_min, axis_max],
            [axis_min, axis_max],
            linestyle="--",
            color="grey",
            linewidth=2,
        )

        # Restore the original limits
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)

        # Convert test data for metrics calculation
        df_metrics = df.rename(
            columns={
                "true_test": "true",
                "pred_test": "pred",
            }
        )

        # Compute requested metrics
        text = metrics_text(df_metrics, metrics, metrics_precision, unit, unit_scale)

        if subfigure_label is not None:
            text = f"{subfigure_label}\n{text}"

        ax.text(
            0.025,
            0.975,
            text,
            transform=ax.transAxes,
            ha="left",
            va="top",
        )

        if legend_labels is not None:
            if len(legend_labels) != len(true_types):
                raise ValueError(
                    f"legend_labels length ({len(legend_labels)}) must match number of data series ({len(true_types)})"
                )
            ax.legend(legend_labels, loc="lower right")

        if out_png is not None:
            plt.savefig(out_png, format="png", dpi=300, bbox_inches="tight")
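
Example (synthetic data using the default train/valid/test column layout; the displayed metrics are computed from the `true_test`/`pred_test` columns):

```python
import numpy as np
import pandas as pd

from cgcnn2.utils import plot_scatter

rng = np.random.default_rng(0)
df = pd.DataFrame(
    {col: rng.normal(size=200) for col in ("true_train", "true_valid", "true_test")}
)
for split, noise in (("train", 0.10), ("valid", 0.15), ("test", 0.20)):
    df[f"pred_{split}"] = df[f"true_{split}"] + rng.normal(scale=noise, size=200)

plot_scatter(
    df,
    xlabel="true (eV)",
    ylabel="pred (eV)",
    legend_labels=["train", "valid", "test"],
    out_png="scatter.png",
)
```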

print_checkpoint_info(checkpoint, model_path)

Prints the checkpoint information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint` | `dict[str, Any]` | The checkpoint dictionary. | required |
| `model_path` | `str` | The path to the model file. | required |

Source code in cgcnn2/utils.py
def print_checkpoint_info(checkpoint: dict[str, Any], model_path: str) -> None:
    """
    Prints the checkpoint information.

    Args:
        checkpoint (dict[str, Any]): The checkpoint dictionary.
        model_path (str): The path to the model file.
    """
    epoch = checkpoint.get("epoch", "N/A")
    mse = checkpoint.get("best_mse_error")
    mae = checkpoint.get("best_mae_error")

    metrics = []
    if mse is not None:
        metrics.append(f"MSE={mse:.4f}")
    if mae is not None:
        metrics.append(f"MAE={mae:.4f}")

    metrics_str = ", ".join(metrics) if metrics else "N/A"

    logging.info(
        f"=> Loaded model from '{model_path}' (epoch {epoch}, validation {metrics_str})"
    )

seed_everything(seed)

Seeds the random number generators for Python, NumPy, PyTorch, and PyTorch CUDA.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `seed` | `int` | The seed value to use for random number generation. | required |

Source code in cgcnn2/utils.py
def seed_everything(seed: int) -> None:
    """
    Seeds the random number generators for Python, NumPy, PyTorch, and PyTorch CUDA.

    Args:
        seed (int): The seed value to use for random number generation.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

setup_logging()

Sets up logging for the project.

Source code in cgcnn2/utils.py
def setup_logging() -> None:
    """
    Sets up logging for the project.
    """
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format="%(asctime)s.%(msecs)03d %(levelname)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    logging.captureWarnings(True)

    logging.info(f"cgcnn2 version: {cgcnn2.__version__}")
    logging.info(f"cuda version: {torch.version.cuda}")
    logging.info(f"torch version: {torch.__version__}")

unique_structures_clean(dataset_dir, delete_duplicates=False)

Checks for duplicate (structurally equivalent) structures in a directory of CIF files using pymatgen's StructureMatcher and returns the groups of structurally equivalent files.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset_dir` | `str` | The path to the dataset containing CIF files. | required |
| `delete_duplicates` | `bool` | Whether to delete the duplicate structures. | `False` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `grouped` | `list` | A list of lists, where each sublist contains the filenames of structurally equivalent structures. |


Source code in cgcnn2/utils.py
def unique_structures_clean(dataset_dir, delete_duplicates=False):
    """
    Checks for duplicate (structurally equivalent) structures in a directory
    of CIF files using pymatgen's StructureMatcher and returns the groups
    of structurally equivalent files.

    Args:
        dataset_dir (str): The path to the dataset containing CIF files.
        delete_duplicates (bool): Whether to delete the duplicate structures.

    Returns:
        grouped (list): A list of lists, where each sublist contains the filenames of structurally equivalent structures.
    """
    cif_files = [f for f in os.listdir(dataset_dir) if f.endswith(".cif")]
    structures = []
    filenames = []

    for fname in cif_files:
        full_path = os.path.join(dataset_dir, fname)
        structures.append(Structure.from_file(full_path))
        filenames.append(fname)

    id_to_fname = {id(s): fn for s, fn in zip(structures, filenames)}

    matcher = StructureMatcher()
    grouped = matcher.group_structures(structures)

    grouped_fnames = [[id_to_fname[id(s)] for s in group] for group in grouped]

    if delete_duplicates:
        for file_group in grouped_fnames:
            # keep the first file, delete the rest
            for dup_fname in file_group[1:]:
                os.remove(os.path.join(dataset_dir, dup_fname))

    return grouped_fnames
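
Example (a hypothetical directory; with delete_duplicates left False the call is read-only and only reports the grouping):

```python
from cgcnn2.utils import unique_structures_clean

groups = unique_structures_clean("./cif_dir", delete_duplicates=False)

duplicates = [name for group in groups for name in group[1:]]
print(f"{len(groups)} unique structures, {len(duplicates)} duplicates")
```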