Shortcuts

Source code for qdrant_client.local.persistence

import base64
import dbm
import logging
import pickle
import sqlite3
from pathlib import Path
from typing import Iterable, Optional

from qdrant_client.http import models

STORAGE_FILE_NAME_OLD = "storage.dbm"
STORAGE_FILE_NAME = "storage.sqlite"


[docs]def try_migrate_to_sqlite(location: str) -> None: dbm_path = Path(location) / STORAGE_FILE_NAME_OLD sql_path = Path(location) / STORAGE_FILE_NAME if sql_path.exists(): return if not dbm_path.exists(): return try: dbm_storage = dbm.open(str(dbm_path), "c") con = sqlite3.connect(str(sql_path)) cur = con.cursor() # Create table cur.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)") for key in dbm_storage.keys(): value = dbm_storage[key] if isinstance(key, str): key = key.encode("utf-8") key = pickle.loads(key) sqlite_key = CollectionPersistence.encode_key(key) # Insert a row of data cur.execute( "INSERT INTO points VALUES (?, ?)", ( sqlite_key, sqlite3.Binary(value), ), ) con.commit() con.close() dbm_storage.close() dbm_path.unlink() except Exception as e: logging.error("Failed to migrate dbm to sqlite:", e) logging.error( "Please try to use previous version of qdrant-client or re-create collection" ) raise e
[docs]class CollectionPersistence: CHECK_SAME_THREAD: Optional[bool] = None
[docs] @classmethod def encode_key(cls, key: models.ExtendedPointId) -> str: return base64.b64encode(pickle.dumps(key)).decode("utf-8")
def __init__(self, location: str, force_disable_check_same_thread: bool = False): """ Create or load a collection from the local storage. Args: location: path to the collection directory. """ try_migrate_to_sqlite(location) self.location = Path(location) / STORAGE_FILE_NAME self.location.parent.mkdir(exist_ok=True, parents=True) if self.CHECK_SAME_THREAD is None and force_disable_check_same_thread is False: with sqlite3.connect(":memory:") as tmp_conn: # it is unsafe to use `sqlite3.threadsafety` until python3.11 since it was hardcoded to 1, thus we # need to fetch threadsafe with a query # THREADSAFE = 0: Threads may not share the module # THREADSAFE = 1: Threads may share the module, connections and cursors. Default for Linux. # THREADSAFE = 2: Threads may share the module, but not connections. Default for macOS. threadsafe = tmp_conn.execute( "select * from pragma_compile_options where compile_options like 'THREADSAFE=%'" ).fetchone()[0] self.__class__.CHECK_SAME_THREAD = threadsafe != "THREADSAFE=1" if force_disable_check_same_thread: self.__class__.CHECK_SAME_THREAD = False self.storage = sqlite3.connect( str(self.location), check_same_thread=self.CHECK_SAME_THREAD # type: ignore ) self._ensure_table()
[docs] def close(self) -> None: self.storage.close()
def _ensure_table(self) -> None: cursor = self.storage.cursor() cursor.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)") self.storage.commit()
[docs] def persist(self, point: models.PointStruct) -> None: """ Persist a point in the local storage. Args: point: point to persist """ key = self.encode_key(point.id) value = pickle.dumps(point) cursor = self.storage.cursor() # Insert or update by key cursor.execute( "INSERT OR REPLACE INTO points VALUES (?, ?)", ( key, sqlite3.Binary(value), ), ) self.storage.commit()
[docs] def delete(self, point_id: models.ExtendedPointId) -> None: """ Delete a point from the local storage. Args: point_id: id of the point to delete """ key = self.encode_key(point_id) cursor = self.storage.cursor() cursor.execute( "DELETE FROM points WHERE id = ?", (key,), ) self.storage.commit()
[docs] def load(self) -> Iterable[models.PointStruct]: """ Load a point from the local storage. Returns: point: loaded point """ cursor = self.storage.cursor() cursor.execute("SELECT point FROM points") for row in cursor.fetchall(): yield pickle.loads(row[0])
[docs]def test_persistence() -> None: import tempfile with tempfile.TemporaryDirectory() as tmpdir: persistence = CollectionPersistence(tmpdir) point = models.PointStruct(id=1, vector=[1.0, 2.0, 3.0], payload={"a": 1}) persistence.persist(point) for loaded_point in persistence.load(): assert loaded_point == point break del persistence persistence = CollectionPersistence(tmpdir) for loaded_point in persistence.load(): assert loaded_point == point break persistence.delete(point.id) persistence.delete(point.id) for _ in persistence.load(): assert False, "Should not load anything"