@hgbrian
Created November 30, 2024 18:23
Run Boltz through Gemini
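  • src/boltz/data/parse/schema.py line 507: chain_type = const.chain_type_ids[entity_type.upper()]. If entity_type is "ligand" in this branch, the lookup key becomes "LIGAND", which results in a KeyError. The chain type should instead come from the type stored on the parsed_chain object: the patch below stores the type name "NONPOLYMER" on ParsedChain and converts it to an ID with const.chain_type_ids when the ChainInfo records are built.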
--- a/src/boltz/data/parse/schema.py
+++ b/src/boltz/data/parse/schema.py
@@ -504,7 +504,7 @@
             parsed_chain = ParsedChain(
                 entity=entity_id,
                 residues=residues,
-                type=const.chain_type_ids["NONPOLYMER"],
+                type="NONPOLYMER",
             )
         elif (entity_type == "ligand") and ("smiles" in items[0][entity_type]):
             seq = items[0][entity_type]["smiles"]
@@ -523,7 +523,7 @@
             parsed_chain = ParsedChain(
                 entity=entity_id,
                 residues=[residue],
-                type=const.chain_type_ids["NONPOLYMER"],
+                type="NONPOLYMER",
             )
         else:
             msg = f"Invalid entity type: {entity_type}"
@@ -650,6 +650,7 @@
     struct_info = StructureInfo(num_chains=len(chains))
     chain_infos = []
     for chain in chains:
+        chain_type = const.chain_type_ids[chains[chain["name"]].type]
         chain_info = ChainInfo(
             chain_id=int(chain["asym_id"]),
             chain_name=chain["name"],
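The KeyError comes from upper-casing the raw entity key and using it directly as a chain-type name. A minimal sketch of routing every entity type through an explicit name table before the ID lookup, assuming (as the report implies) that the ID table has a NONPOLYMER entry but no LIGAND entry. Both dictionaries and the numeric IDs are hypothetical stand-ins, not Boltz's actual constants:

```python
# Hypothetical chain-type table standing in for const.chain_type_ids;
# the real IDs in Boltz may differ.
CHAIN_TYPE_IDS = {"PROTEIN": 0, "DNA": 1, "RNA": 2, "NONPOLYMER": 3}

# Hypothetical mapping from input entity keys to chain-type names.
# "ligand" is routed to NONPOLYMER instead of being upper-cased to "LIGAND".
ENTITY_TO_CHAIN_TYPE = {
    "protein": "PROTEIN",
    "dna": "DNA",
    "rna": "RNA",
    "ligand": "NONPOLYMER",
}


def chain_type_id(entity_type: str) -> int:
    """Resolve an entity type such as "ligand" to a numeric chain-type ID."""
    try:
        name = ENTITY_TO_CHAIN_TYPE[entity_type.lower()]
    except KeyError:
        # Mirror the parser's own error message instead of leaking a bare KeyError.
        raise ValueError(f"Invalid entity type: {entity_type}") from None
    return CHAIN_TYPE_IDS[name]


print(chain_type_id("ligand"))   # 3
print(chain_type_id("Protein"))  # 0
```

Keeping the entity-to-name mapping separate from the name-to-ID table means an unknown entity type fails with a descriptive message rather than an opaque KeyError deep in the parser.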
  • src/boltz/data/msa/mmseqs2.py lines 324, 331: a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m"). The bfd a3m filename uses periods (.) as separators, which this report treats as an invalid path; the patch below replaces them with underscores (_) and keeps the same {path}/ prefix as uniref.a3m.
--- a/src/boltz/data/msa/mmseqs2.py
+++ b/src/boltz/data/msa/mmseqs2.py
@@ -321,7 +321,7 @@
     else:
         a3m_files = [f"{path}/uniref.a3m"]
         if use_env:
-            a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")
+            a3m_files.append(f"{path}/bfd_mgnify30_metaeuk30_smag30.a3m")
 
     # extract a3m files
     if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
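Since the underlying question is which filename actually exists on disk, a check that reports exactly which expected a3m files are missing (rather than only whether any of them is) makes a naming mismatch like this easy to spot. A sketch using a hypothetical helper, missing_a3m_files, with the environmental-database filename passed in as a parameter rather than hard-coded:

```python
import os
import tempfile
from typing import List


def missing_a3m_files(path: str, use_env: bool, env_a3m: str) -> List[str]:
    """Return the expected a3m paths under `path` that do not exist on disk.

    `env_a3m` is the environmental-database filename; it is a parameter here
    because its exact spelling is what the report above questions.
    """
    a3m_files = [os.path.join(path, "uniref.a3m")]
    if use_env:
        a3m_files.append(os.path.join(path, env_a3m))
    return [f for f in a3m_files if not os.path.isfile(f)]


# Usage: a result directory that only contains uniref.a3m.
with tempfile.TemporaryDirectory() as result_dir:
    open(os.path.join(result_dir, "uniref.a3m"), "w").close()
    missing = missing_a3m_files(
        result_dir, use_env=True, env_a3m="bfd_mgnify30_metaeuk30_smag30.a3m"
    )
    missing_names = [os.path.basename(p) for p in missing]
    print(missing_names)  # ['bfd_mgnify30_metaeuk30_smag30.a3m']
```

Printing the missing names in the error message, instead of a generic failure, shows immediately whether the separator spelling is the problem.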
  • src/boltz/data/feature/symmetry.py line 274: true_coords = best_true_coords.clone() assumes best_true_coords was assigned inside the preceding loop. If weighted_minimum_rmsd_single raises an exception, that assignment can be skipped, so the clone fails and true_coords is never defined, leading to errors later. The proposed fix replaces continue with break in the except block (lines 269-271) so the loop stops at the first error; this only helps if an earlier iteration has already set best_true_coords.
--- a/src/boltz/data/feature/symmetry.py
+++ b/src/boltz/data/feature/symmetry.py
@@ -269,7 +269,7 @@
                 **args_rmsd,
             )
         except:
             print("Warning: error in rmsd computation inside symmetry code")
-            continue
+            break
         rmsd = rmsd.item()
 
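A more defensive variant of this "keep the best candidate" loop keeps skipping attempts that fail but raises a clear error when none succeeds, instead of deferring the failure to a later .clone(). A sketch with hypothetical names (best_candidate, score, fragile_rmsd) rather than Boltz's actual functions:

```python
import math
from typing import Callable, Iterable, Tuple, TypeVar

T = TypeVar("T")


def best_candidate(candidates: Iterable[T], score: Callable[[T], float]) -> Tuple[T, float]:
    """Return the lowest-scoring candidate, skipping candidates whose scoring fails.

    Raises RuntimeError if no candidate could be scored, instead of letting a
    later `.clone()` fail on an unset value.
    """
    found = False
    best = None
    best_score = math.inf
    for cand in candidates:
        try:
            s = score(cand)
        except Exception as exc:  # narrower than a bare `except:`
            print(f"Warning: error scoring {cand!r}: {exc}")
            continue  # keep trying the remaining candidates
        if not found or s < best_score:
            found, best, best_score = True, cand, s
    if not found:
        raise RuntimeError("every candidate failed; no best value to return")
    return best, best_score


def fragile_rmsd(x: float) -> float:
    """Toy scoring function that fails on some inputs."""
    if x < 0:
        raise ValueError("negative input")
    return abs(x - 2.0)


print(best_candidate([-1.0, 1.0, 2.5, 3.0], fragile_rmsd))  # (2.5, 0.5)
```

The explicit found flag also covers the case where a candidate legitimately scores math.inf, which a plain comparison against the initial best_score would miss.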
  • src/boltz/data/feature/featurizer.py, line 494: atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1). Here token_id is the loop variable left over from the previous loop, so the number of classes silently depends on how that loop ran (if it enumerates data.tokens, token_id + 1 happens to equal the token count, but token_id is undefined when there are no tokens). Every token must have a column in the one-hot encoding, so the code should state that directly with num_classes=len(data.tokens).
--- a/src/boltz/data/feature/featurizer.py
+++ b/src/boltz/data/feature/featurizer.py
@@ -494,7 +494,7 @@
     ref_element = one_hot(ref_element, num_classes=const.num_elements)
-    atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
+    atom_to_token = one_hot(atom_to_token, num_classes=len(data.tokens))
     token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
     r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))
 
     # Center the ground truth coordinates
     center = (coords * resolved_mask[None, :, None]).sum(dim=1)
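The difference between the two num_classes expressions can be sketched in plain Python. Here one_hot is a hypothetical stand-in for torch.nn.functional.one_hot (it raises ValueError rather than a torch error on out-of-range indices), and the token and atom lists are made-up data; the point is what a for loop leaves behind in its index variable:

```python
from typing import List


def one_hot(indices: List[int], num_classes: int) -> List[List[int]]:
    """Encode each index as a row with a single 1 (stand-in for torch's one_hot)."""
    rows = []
    for i in indices:
        if not 0 <= i < num_classes:
            raise ValueError(f"index {i} out of range for {num_classes} classes")
        row = [0] * num_classes
        row[i] = 1
        rows.append(row)
    return rows


tokens = ["ALA", "GLY", "SER"]   # hypothetical token list
atom_to_token = [0, 0, 1, 2, 2]  # hypothetical atom -> token index map

for token_id, _token in enumerate(tokens):
    pass  # per-token featurization would happen here

# After the loop, token_id holds the LAST index, i.e. len(tokens) - 1,
# so token_id + 1 matches len(tokens) only because the loop ran over every token.
assert token_id + 1 == len(tokens)

# len(tokens) states the intent directly and also works for an empty token list,
# where token_id would never have been bound.
encoded = one_hot(atom_to_token, num_classes=len(tokens))
print(len(encoded), len(encoded[0]))  # 5 3
```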
  • src/boltz/model/model.py, lines 861 and 865: the values logged for val/rmsd and val/best_rmsd are passed to self.log as tensors, without an .item() call. Adding .item() to self.rmsd.compute() and self.best_rmsd.compute() converts each result to a plain Python scalar that is detached from the computation graph.
--- a/src/boltz/model/model.py
+++ b/src/boltz/model/model.py
@@ -861,12 +861,12 @@
                 sync_dist=True,
             )
 
-        self.log("val/rmsd", self.rmsd.compute(), prog_bar=True, sync_dist=True)
+        self.log("val/rmsd", self.rmsd.compute().item(), prog_bar=True, sync_dist=True)
         self.rmsd.reset()
 
         self.log(
-            "val/best_rmsd", self.best_rmsd.compute(), prog_bar=True, sync_dist=True
+            "val/best_rmsd", self.best_rmsd.compute().item(), prog_bar=True, sync_dist=True
         )
         self.best_rmsd.reset()
 
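If several metrics need the same treatment, the conversion can live in one helper. A sketch in which to_log_scalar is a hypothetical name (not a Boltz or Lightning function) and FakeMetricValue stands in for the one-element tensor a metric's compute() returns:

```python
from typing import Union


def to_log_scalar(value) -> Union[int, float]:
    """Convert a metric value to a plain Python number before logging.

    Anything exposing .item() (e.g. a one-element torch.Tensor or a NumPy
    scalar) is converted via .item(); plain ints and floats pass through.
    """
    item = getattr(value, "item", None)
    if callable(item):
        return item()
    if isinstance(value, (int, float)):
        return value
    raise TypeError(f"cannot log value of type {type(value).__name__}")


class FakeMetricValue:
    """Stand-in for a one-element tensor returned by a metric's compute()."""

    def __init__(self, x: float) -> None:
        self._x = x

    def item(self) -> float:
        return self._x


print(to_log_scalar(FakeMetricValue(1.75)))  # 1.75
print(to_log_scalar(3))                      # 3
```

With such a helper, the first logging call in the patch would read self.log("val/rmsd", to_log_scalar(self.rmsd.compute()), prog_bar=True, sync_dist=True).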