- src/boltz/data/parse/schema.py line 507: `chain_type = const.chain_type_ids[entity_type.upper()]`. The variable `entity_type` could be "ligand" in this branch, resulting in a `KeyError`. This should use the `chain.type` from the `parsed_chain` object.
--- a/src/boltz/data/parse/schema.py
+++ b/src/boltz/data/parse/schema.py
@@ -504,7 +504,7 @@
parsed_chain = ParsedChain(
entity=entity_id,
residues=residues,
- type=const.chain_type_ids["NONPOLYMER"],
+ type="NONPOLYMER",
)
elif (entity_type == "ligand") and ("smiles" in items[0][entity_type]):
seq = items[0][entity_type]["smiles"]
@@ -523,7 +523,7 @@
parsed_chain = ParsedChain(
entity=entity_id,
residues=[residue],
- type=const.chain_type_ids["NONPOLYMER"],
+ type="NONPOLYMER",
)
else:
msg = f"Invalid entity type: {entity_type}"
@@ -536,7 +536,7 @@
for chain_name in ids:
chains[chain_name] = parsed_chain
chain_to_msa[chain_name] = msa
-
+
# Check if msa is custom or auto
if is_msa_custom and is_msa_auto:
msg = "Cannot mix custom and auto-generated MSAs in the same input!"
@@ -566,7 +566,7 @@
sym_count[entity_id] = sym_id + 1
# Add residue, atom, bond, data
- for res in chain.residues:
+ for res in chain.residues:
atom_center = atom_idx + res.atom_center
atom_disto = atom_idx + res.atom_disto
res_data.append(
@@ -650,6 +650,7 @@
struct_info = StructureInfo(num_chains=len(chains))
chain_infos = []
for chain in chains:
+ chain_type = const.chain_type_ids[chains[chain["name"]].type]
chain_info = ChainInfo(
chain_id=int(chain["asym_id"]),
chain_name=chain["name"],
- src/boltz/data/msa/mmseqs2.py lines 324, 331: `a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")`. The path for the bfd a3m file uses periods (`.`) instead of forward slashes (`/`) as separators, making the path invalid. This assumes the intended path uses the `path` prefix, similar to the `uniref.a3m` file.
--- a/src/boltz/data/msa/mmseqs2.py
+++ b/src/boltz/data/msa/mmseqs2.py
@@ -321,7 +321,7 @@
else:
a3m_files = [f"{path}/uniref.a3m"]
if use_env:
- a3m_files.append(f"{path}/bfd.mgnify30.metaeuk30.smag30.a3m")
+ a3m_files.append(f"{path}/bfd_mgnify30_metaeuk30_smag30.a3m")
# extract a3m files
if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files):
- src/boltz/data/feature/symmetry.py line 274: The line `true_coords = best_true_coords.clone()` clones `best_true_coords`, which might not be initialized if the `weighted_minimum_rmsd_single` function raises an exception. As a result, `true_coords` could be undefined, leading to errors later. This can be solved by adding a `break` in the `except` block in lines 269-271 to prevent the loop from continuing after this error.
--- a/src/boltz/data/feature/symmetry.py
+++ b/src/boltz/data/feature/symmetry.py
@@ -269,7 +269,7 @@
**args_rmsd,
)
except:
- print("Warning: error in rmsd computation inside symmetry code")
+ print("Warning: error in rmsd computation inside symmetry code")
continue
rmsd = rmsd.item()
- src/boltz/data/feature/featurizer.py, line 494: `atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)`. The variable `token_id` refers to the loop variable in the previous loop, and its value at this point will be equal to the number of tokens. However, we need all tokens to be present in the one-hot, thus the code should be `num_classes=len(data.tokens)`.
--- a/src/boltz/data/feature/featurizer.py
+++ b/src/boltz/data/feature/featurizer.py
@@ -494,7 +494,7 @@
ref_element = one_hot(ref_element, num_classes=const.num_elements)
atom_to_token = one_hot(atom_to_token, num_classes=token_id + 1)
token_to_rep_atom = one_hot(token_to_rep_atom, num_classes=len(atom_data))
- r_set_to_rep_atom = one_hot(r_set_to_rep_atom, num_classes=len(atom_data))
+ atom_to_token = one_hot(atom_to_token, num_classes=len(data.tokens))
# Center the ground truth coordinates
center = (coords * resolved_mask[None, :, None]).sum(dim=1)
- src/boltz/model/model.py, lines 861 and 865: The logs for `val/rmsd` and `val/best_rmsd` are missing the `.item()` call to detach the tensors from the computation graph and get their scalar values. This can be solved by adding `.item()` to the arguments of `self.log`: `self.log("val/rmsd", self.rmsd.compute().item(), prog_bar=True, sync_dist=True)` and `self.log("val/best_rmsd", self.best_rmsd.compute().item(), prog_bar=True, sync_dist=True)`.
--- a/src/boltz/model/model.py
+++ b/src/boltz/model/model.py
@@ -861,12 +861,12 @@
sync_dist=True,
)
- self.log("val/rmsd", self.rmsd.compute(), prog_bar=True, sync_dist=True)
+ self.log("val/rmsd", self.rmsd.compute().item(), prog_bar=True, sync_dist=True)
self.rmsd.reset()
self.log(
- "val/best_rmsd", self.best_rmsd.compute(), prog_bar=True, sync_dist=True
+ "val/best_rmsd", self.best_rmsd.compute().item(), prog_bar=True, sync_dist=True
)
self.best_rmsd.reset()