Last active
April 27, 2022 21:50
-
-
Save orbingol/5cbcee7cafcf4e26447d87fe36b6467a to your computer and use it in GitHub Desktop.
Python copy & deepcopy with dicts and slots
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy | |
from itertools import chain | |
class CopyTestClass(object):
    """Demonstrate custom ``copy``/``deepcopy`` hooks on a ``__dict__`` class.

    A shallow copy shares every attribute value with the original.  A deep
    copy duplicates every attribute except the ``_cache`` dict, which is
    replaced by a fresh empty dict in the copy.
    """

    def __init__(self, q, w):
        """Store ``q`` as ``_a`` and ``w`` as ``_b``; start with an empty cache."""
        self._a = q
        self._b = w
        self._cache = {}

    def __copy__(self):
        """Return a new instance whose attributes reference the same objects."""
        cls = self.__class__
        # Bypass __init__: the attributes are transplanted directly below.
        result = cls.__new__(cls)
        result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        """Return a deep copy of this instance with an emptied cache.

        ``memo`` is the id-to-copy mapping that ``copy.deepcopy`` threads
        through the whole object graph.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        # Register the copy before descending so that references back to
        # ``self`` inside the attributes resolve to ``result``.
        memo[id(self)] = result
        # Don't copy the cache: pre-seed memo so deepcopy hands back an empty
        # dict for it.  Only do so for an actual dict -- the previous
        # ``self._cache.__new__(dict)`` raised TypeError for any other value
        # (None, OrderedDict, ...).  A non-dict cache is deep-copied normally.
        cache = getattr(self, "_cache", None)
        if isinstance(cache, dict):
            memo[id(cache)] = {}
        # Deep copy all attributes (the cache resolves through memo).
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        return result
class CopyTestClassWithSlots(object):
    """Demonstrate custom ``copy``/``deepcopy`` hooks on a ``__slots__`` class.

    A shallow copy shares every attribute value with the original.  A deep
    copy duplicates every attribute except the ``_cache`` dict, which is
    replaced by a fresh empty dict in the copy.  Slots that are currently
    unset stay unset in the copy.
    """

    __slots__ = ('_a', '_b', '_cache')

    def __init__(self, q, w):
        """Store ``q`` as ``_a`` and ``w`` as ``_b``; start with an empty cache."""
        self._a = q
        self._b = w
        self._cache = {}

    def _iter_slot_names(self):
        """Yield, once each, the attribute names of all slots along the MRO."""
        seen = set()
        for klass in type(self).__mro__:
            # Look only at what this class itself declares; getattr() would
            # also return a parent's __slots__ for a subclass without its own.
            declared = vars(klass).get('__slots__', ())
            if isinstance(declared, str):
                # A bare string declares a single slot, not one per character.
                declared = (declared,)
            owner = klass.__name__.lstrip('_')
            for name in declared:
                if name in ('__dict__', '__weakref__'):
                    continue
                # Private slot names are stored under their mangled name.
                if owner and name.startswith('__') and not name.endswith('__'):
                    name = '_' + owner + name
                if name not in seen:
                    seen.add(name)
                    yield name

    def __copy__(self):
        """Return a new instance whose attributes reference the same objects."""
        cls = self.__class__
        # Bypass __init__: the attributes are transplanted directly below.
        result = cls.__new__(cls)
        for name in self._iter_slot_names():
            try:
                value = getattr(self, name)
            except AttributeError:
                # Slot declared but never assigned (or deleted): leave unset.
                continue
            # Share the value itself; copying it here would make the
            # "shallow" copy one level deeper than it should be.
            setattr(result, name, value)
        # A subclass that does not define __slots__ also has a __dict__.
        if hasattr(self, '__dict__'):
            result.__dict__.update(self.__dict__)
        return result

    def __deepcopy__(self, memo):
        """Return a deep copy of this instance with an emptied cache.

        ``memo`` is the id-to-copy mapping that ``copy.deepcopy`` threads
        through the whole object graph.
        """
        cls = self.__class__
        result = cls.__new__(cls)
        # Register the copy before descending so that references back to
        # ``self`` inside the attributes resolve to ``result``.
        memo[id(self)] = result
        # Don't copy the cache: pre-seed memo so deepcopy hands back an empty
        # dict for it.  Only do so for an actual dict; any other value is
        # deep-copied normally instead of raising TypeError.
        cache = getattr(self, '_cache', None)
        if isinstance(cache, dict):
            memo[id(cache)] = {}
        for name in self._iter_slot_names():
            try:
                value = getattr(self, name)
            except AttributeError:
                # Slot declared but never assigned (or deleted): leave unset.
                continue
            setattr(result, name, copy.deepcopy(value, memo))
        # A subclass that does not define __slots__ also has a __dict__.
        if hasattr(self, '__dict__'):
            for k, v in self.__dict__.items():
                setattr(result, k, copy.deepcopy(v, memo))
        return result
# Testing deep copy with __dict__: the copy gets its own list and an empty cache
test1 = CopyTestClass(q=10, w=[1.0 for _ in range(10)])
test1._cache["coins"] = 120
test1c = copy.deepcopy(test1)

# Testing copy with __slots__ (must use the slotted class, not CopyTestClass)
test2 = CopyTestClassWithSlots(q=25, w=None)
test2c = copy.copy(test2)

# Testing deep copy with __slots__ (must use the slotted class, not CopyTestClass)
test3 = CopyTestClassWithSlots(q=[1.0 for _ in range(4)], w=None)
test3._b = "a string"
test3._cache['test_me'] = 1020
test3c = copy.deepcopy(test3)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment