Spaces:
Runtime error
Runtime error
| import zipfile | |
| import os.path as osp | |
| # import lmdb | |
| import logging | |
| from PIL import Image | |
| import pickle | |
| import io | |
| import glob | |
| import os | |
| from pathlib import Path | |
| import time | |
| from threading import Thread | |
| from PIL import ImageFile | |
| ImageFile.LOAD_TRUNCATED_IMAGES = True | |
# Current user's home directory, as a plain string.
home = os.fspath(Path.home())

# Symlink-resolved form of the blob mount point.
abs_blob_path = os.path.realpath("/mnt/blob/")

# "caching" directory directly under the user's home.
CACHE_FOLDER = os.path.join(home, "caching")

# Cache switch; not read anywhere in the visible part of this file.
USE_CACHE = True
def norm(path):
    """Return the absolute, symlink-resolved form of ``path``.

    Glob patterns are not accepted: an ``AssertionError`` is raised when
    ``path`` contains ``*``.
    """
    assert "*" not in path
    absolute = os.path.abspath(path)
    return os.path.realpath(absolute)
def in_blob(file):
    """Tell whether ``file`` contains the blob mount path.

    This is a plain substring test of ``abs_blob_path`` against ``file``;
    the argument is used as given and is not normalised first.
    """
    return abs_blob_path in file
def map_name(file):
    """Flatten a file path into a single name without directory separators.

    The path is normalised with ``norm``, the blob mount prefix
    (``abs_blob_path`` plus ``/``) is removed when present, any remaining
    leading ``/`` is dropped, and every ``/`` is replaced by ``_``.

    Args:
        file: path to convert; must not contain ``*`` (enforced by ``norm``).

    Returns:
        The flattened name.

    Raises:
        AssertionError: if the flattened name is 250 characters or longer.
    """
    path = norm(file)
    prefix = abs_blob_path + "/"
    # str.lstrip() treats its argument as a *set of characters*, not as a
    # prefix, so the former ``path.lstrip(prefix)`` also ate leading
    # characters of the remaining path (e.g. "/mnt/blob/model/x" lost the
    # "mo" of "model").  Remove the prefix explicitly instead.
    if path == abs_blob_path:
        path = ""
    elif path.startswith(prefix):
        path = path[len(prefix):]
    # Paths outside the blob mount still lose their leading separator so the
    # result never starts with "_".
    path = path.lstrip("/")
    path = path.replace("/", "_")
    assert len(path) < 250, "flattened name is too long"
    return path
def preload(db, sync=False):
    """Invoke ``db.initialize()`` inline or on a new thread.

    Args:
        db: any object exposing an ``initialize()`` method.
        sync: when true, call ``initialize()`` in the calling thread and
            return once it has finished; otherwise start a ``Thread`` that
            runs it and return immediately.
    """
    if not sync:
        worker = Thread(target=db.initialize)
        worker.start()
        return
    db.initialize()
def get_keys_from_lmdb(db):
    """Return every key yielded by a cursor over ``db`` as a list.

    A transaction is opened with ``write=False`` and the cursor is iterated
    with ``values=False``, so only keys are collected.
    """
    with db.begin(write=False) as txn:
        cursor = txn.cursor()
        return [key for key in cursor.iternext(values=False)]
def decode_img(byteflow):
    """Decode encoded image bytes into a loaded RGB PIL image.

    When the bytes cannot be decoded, a warning is logged and the image
    stored in ``white.jpeg`` (resolved against the current working
    directory) is returned instead.

    Args:
        byteflow: raw bytes of an encoded image.

    Returns:
        A fully loaded ``PIL.Image.Image`` in ``RGB`` mode.

    Raises:
        Whatever ``Image.open`` raises if the ``white.jpeg`` placeholder
        itself cannot be read.
    """
    try:
        img = Image.open(io.BytesIO(byteflow)).convert("RGB")
        img.load()
    except Exception as exc:
        # A bare ``except:`` would also trap KeyboardInterrupt/SystemExit and
        # make a data-loading loop impossible to interrupt cleanly.
        logging.warning(
            "decode_img: could not decode image (%r); using white.jpeg", exc
        )
        # convert() returns an independent copy, so the placeholder file can
        # be closed as soon as the conversion is done.
        with Image.open("white.jpeg") as placeholder:
            img = placeholder.convert("RGB")
        img.load()
    return img
def decode_text(byteflow):
    """Unpickle ``byteflow`` and return the resulting object.

    NOTE(review): ``pickle.loads`` can execute arbitrary code while loading;
    this presumably only ever sees trusted archives -- confirm with callers.
    """
    decoded = pickle.loads(byteflow)
    return decoded
# Dispatch table: data-type name -> function decoding one entry's raw bytes.
decode_funcs = dict(
    image=decode_img,
    text=decode_text,
)
class ZipManager:
    """Read and decode the entries of one zip archive.

    Constructing a manager hands it to ``preload``, which indexes the entry
    names by calling ``initialize()`` on a background thread.  A persistent
    archive handle (``zip_fd``) is only opened by the first ``get()`` call.
    """

    def __init__(self, zip_path, data_type, prefix=None) -> None:
        """Set up the manager and start indexing the archive.

        Args:
            zip_path: path of the zip archive to read.
            data_type: key into ``decode_funcs`` selecting the decoder.
            prefix: accepted but not used by this class.
        """
        self.decode_func = decode_funcs[data_type]
        self.zip_path = zip_path
        self._init = False
        # Last error raised by initialize(), so keys() can report it instead
        # of waiting forever for an index that will never arrive.
        self._init_error = None
        preload(self)

    def deinitialze(self):
        """Close the persistent archive handle, if one is open.

        Safe to call when nothing is open (it used to raise
        ``AttributeError`` in that case).
        """
        handle = self.__dict__.pop("zip_fd", None)
        self._init = False
        if handle is not None:
            handle.close()

    # Correctly spelled alias; the misspelled name stays for existing callers.
    deinitialize = deinitialze

    def initialize(self, close=True):
        """Index the archive and optionally keep it open for reading.

        Args:
            close: when true (the default, used by the background preload)
                only the list of entry names is collected, through a
                temporary handle that is closed again.  When false, a
                persistent handle is stored in ``zip_fd`` for ``get()``.

        Raises:
            Any error raised by ``zipfile.ZipFile`` while opening the
            archive; it is also remembered so ``keys()`` can report it.
        """
        try:
            if close:
                # Use a private handle: the previous implementation went
                # through ``self.zip_fd``/``self._init`` and could close or
                # delete the handle a concurrent get() had just opened.
                if not hasattr(self, "_keys"):
                    with zipfile.ZipFile(self.zip_path, mode="r") as archive:
                        self._keys = archive.namelist()
            elif not (self._init and hasattr(self, "zip_fd")):
                archive = zipfile.ZipFile(self.zip_path, mode="r")
                if not hasattr(self, "_keys"):
                    self._keys = archive.namelist()
                self.zip_fd = archive
                self._init = True
        except Exception as exc:
            self._init_error = exc
            raise
        self._init_error = None

    def keys(self):
        """Return the archive's entry names, waiting until they are indexed.

        Raises:
            RuntimeError: if indexing failed, chained to the original error.
        """
        while not hasattr(self, "_keys"):
            error = getattr(self, "_init_error", None)
            if error is not None:
                raise RuntimeError(
                    f"could not index zip archive {self.zip_path!r}"
                ) from error
            time.sleep(0.1)
        return self._keys

    def get(self, name):
        """Read entry ``name`` from the archive and return it decoded."""
        if not self._init:
            self.initialize(close=False)
        byteflow = self.zip_fd.read(name)
        return self.decode_func(byteflow)
class MultipleZipManager:
    """Expose several zip archives as a single key -> decoded entry store.

    One ``ZipManager`` is created per archive, and ``mapping`` records which
    archive holds each entry name.  When the same name occurs in more than
    one archive, the archive listed later in ``files`` wins.
    """

    def __init__(self, files: list, data_type, sync=True):
        """Create the manager and build its index.

        Args:
            files: paths of the zip archives to combine.
            data_type: decoder key passed on to every ``ZipManager``.
            sync: when true, build the index before returning; otherwise
                hand the work to ``preload`` and return immediately.
        """
        self.files = files
        self._is_init = False
        self.data_type = data_type
        if sync:
            print("sync", files)
            self.initialize()
        else:
            print("async", files)
            preload(self)
        print("initialize over")

    def initialize(self):
        """Create the per-archive managers and the entry-name -> archive map."""
        self.mapping = {}
        self.managers = {}
        # Create all managers first, then collect their keys in a second pass.
        for file in self.files:
            self.managers[file] = ZipManager(file, self.data_type)
        for file, manager in self.managers.items():
            print(file)
            logging.info(f"{file} loading")
            # ``keys`` is a method: it must be *called*.  Iterating the bound
            # method itself raised TypeError, so initialisation never
            # completed.
            keys = manager.keys()
            for key in keys:
                self.mapping[key] = file
            logging.info(f"{file} loaded, size = {len(keys)}")
            print("loaded")
        self._keys = list(self.mapping.keys())
        self._is_init = True

    def keys(self):
        """Return all entry names, waiting until ``initialize()`` has finished."""
        while not self._is_init:
            time.sleep(0.1)
        return self._keys

    def get(self, name):
        """Return the decoded entry ``name`` from the archive that holds it."""
        data = self.managers[self.mapping[name]].get(name)
        return data