Shortcuts

Source code for qdrant_client.local.persistence

import base64
import dbm
import logging
import pickle
import sqlite3
from pathlib import Path
from typing import Iterable, Optional

from qdrant_client.http import models

# Legacy dbm-based storage file name; try_migrate_to_sqlite converts it into STORAGE_FILE_NAME.
STORAGE_FILE_NAME_OLD = "storage.dbm"
# Current sqlite storage file name, placed inside each collection directory.
STORAGE_FILE_NAME = "storage.sqlite"


def _copy_dbm_to_sqlite(dbm_path: Path, sql_path: Path) -> None:
    """Copy every record of the legacy dbm file at ``dbm_path`` into a new sqlite file.

    Both the dbm handle and the sqlite connection are closed on every exit
    path, including a failure part-way through the copy.

    Args:
        dbm_path: existing legacy dbm storage file.
        sql_path: sqlite file to create and fill.
    """
    dbm_storage = dbm.open(str(dbm_path), "c")
    try:
        con = sqlite3.connect(str(sql_path))
        try:
            cur = con.cursor()
            # Same table layout that CollectionPersistence._ensure_table creates.
            cur.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
            for raw_key in dbm_storage.keys():
                value = dbm_storage[raw_key]
                if isinstance(raw_key, str):
                    raw_key = raw_key.encode("utf-8")
                # dbm keys are pickled point ids. NOTE: unpickling trusts the
                # local storage file; never point this at untrusted data.
                point_id = pickle.loads(raw_key)
                sqlite_key = CollectionPersistence.encode_key(point_id)
                cur.execute(
                    "INSERT INTO points VALUES (?, ?)",
                    (sqlite_key, sqlite3.Binary(value)),
                )
            con.commit()
        finally:
            con.close()
    finally:
        dbm_storage.close()


def try_migrate_to_sqlite(location: str) -> None:
    """Migrate a collection's legacy dbm storage to sqlite, if needed.

    Does nothing when the sqlite storage file already exists in ``location``
    or when there is no legacy dbm file to migrate. After a successful copy,
    the dbm file is deleted.

    Args:
        location: path to the collection directory.

    Raises:
        Exception: whatever the dbm / sqlite / pickle layer raised during the
            copy. Before re-raising, the partially written sqlite file is
            removed so that the next call retries the migration instead of
            silently opening an incomplete copy.
    """
    dbm_path = Path(location) / STORAGE_FILE_NAME_OLD
    sql_path = Path(location) / STORAGE_FILE_NAME

    if sql_path.exists() or not dbm_path.exists():
        return

    try:
        _copy_dbm_to_sqlite(dbm_path, sql_path)
    except Exception as e:
        # `e` must be referenced by a %-placeholder; passed as a bare extra
        # argument, logging fails to format the record and drops it.
        logging.error("Failed to migrate dbm to sqlite: %s", e)
        logging.error(
            "Please try to use previous version of qdrant-client or re-create collection"
        )
        # sql_path did not exist before this call, so anything there now is a
        # partial copy. Left in place, the `sql_path.exists()` guard above would
        # skip migration forever and hide the points still in the dbm file.
        if sql_path.exists():
            sql_path.unlink()
        raise

    dbm_path.unlink()
class CollectionPersistence:
    """Sqlite-backed storage of the points of one local collection.

    Points are pickled and kept in a single ``points`` table whose text
    primary key is the base64-encoded pickle of the point id.
    """

    # ``check_same_thread`` value passed to sqlite3.connect, shared by all
    # instances. ``None`` means "not decided yet"; __init__ fills it in.
    CHECK_SAME_THREAD: Optional[bool] = None

    @classmethod
    def encode_key(cls, key: models.ExtendedPointId) -> str:
        """Encode a point id as the text primary key of the ``points`` table."""
        return base64.b64encode(pickle.dumps(key)).decode("utf-8")

    def __init__(self, location: str, force_disable_check_same_thread: bool = False):
        """
        Create or load a collection from the local storage.

        Args:
            location: path to the collection directory.
            force_disable_check_same_thread: if true, open the sqlite connection
                with ``check_same_thread=False`` regardless of how SQLite was built.
        """
        try_migrate_to_sqlite(location)

        self.location = Path(location) / STORAGE_FILE_NAME
        self.location.parent.mkdir(exist_ok=True, parents=True)

        if self.CHECK_SAME_THREAD is None and not force_disable_check_same_thread:
            self.__class__.CHECK_SAME_THREAD = self._detect_check_same_thread()

        if force_disable_check_same_thread:
            # NOTE: this is stored on the class, so every instance created
            # afterwards also skips the same-thread check.
            self.__class__.CHECK_SAME_THREAD = False

        self.storage = sqlite3.connect(
            str(self.location), check_same_thread=self.CHECK_SAME_THREAD  # type: ignore
        )
        self._ensure_table()

    @staticmethod
    def _detect_check_same_thread() -> bool:
        """Return the ``check_same_thread`` flag that is safe for the linked SQLite.

        It is unsafe to use ``sqlite3.threadsafety`` until python3.11, since it
        was hardcoded to 1, so the THREADSAFE compile option is queried instead:
          THREADSAFE=0: Threads may not share the module.
          THREADSAFE=1: Threads may share the module, connections and cursors. Default for Linux.
          THREADSAFE=2: Threads may share the module, but not connections. Default for macOS.
        The check is disabled only for THREADSAFE=1. If the option cannot be
        found, the check stays enabled.
        """
        # Using a sqlite3 connection as a context manager only commits or rolls
        # back; it does not close it. Close the probe connection explicitly.
        # `PRAGMA compile_options` also works on SQLite builds older than 3.16,
        # which lack the `pragma_compile_options` table-valued function.
        tmp_conn = sqlite3.connect(":memory:")
        try:
            options = [row[0] for row in tmp_conn.execute("PRAGMA compile_options")]
        finally:
            tmp_conn.close()

        threadsafe = next((opt for opt in options if opt.startswith("THREADSAFE=")), None)
        return threadsafe != "THREADSAFE=1"

    def close(self) -> None:
        """Close the underlying sqlite connection."""
        self.storage.close()

    def _ensure_table(self) -> None:
        """Create the ``points`` table if this storage file does not have it yet."""
        self.storage.execute("CREATE TABLE IF NOT EXISTS points (id TEXT PRIMARY KEY, point BLOB)")
        self.storage.commit()

    def persist(self, point: models.PointStruct) -> None:
        """
        Persist a point in the local storage, replacing any point with the same id.

        Args:
            point: point to persist
        """
        key = self.encode_key(point.id)
        value = pickle.dumps(point)
        # Insert or update by key
        self.storage.execute(
            "INSERT OR REPLACE INTO points VALUES (?, ?)",
            (key, sqlite3.Binary(value)),
        )
        self.storage.commit()

    def delete(self, point_id: models.ExtendedPointId) -> None:
        """
        Delete a point from the local storage. Deleting a missing id is a no-op.

        Args:
            point_id: id of the point to delete
        """
        key = self.encode_key(point_id)
        self.storage.execute("DELETE FROM points WHERE id = ?", (key,))
        self.storage.commit()

    def load(self) -> Iterable[models.PointStruct]:
        """
        Yield every point stored in the local storage.

        All rows are fetched before the first point is yielded, so the caller
        iterates over a snapshot taken when iteration starts.

        Returns:
            generator of the stored points, unpickled
        """
        rows = self.storage.execute("SELECT point FROM points").fetchall()
        for (blob,) in rows:
            # NOTE: unpickling trusts the local storage file.
            yield pickle.loads(blob)
def test_persistence() -> None:
    """Round-trip a point through CollectionPersistence, reopen the storage, then delete it."""
    import tempfile

    with tempfile.TemporaryDirectory() as tmpdir:
        point = models.PointStruct(id=1, vector=[1.0, 2.0, 3.0], payload={"a": 1})

        persistence = CollectionPersistence(tmpdir)
        persistence.persist(point)
        # Materialize the generator: a `for ...: assert ...; break` loop passes
        # vacuously when nothing is loaded at all.
        assert list(persistence.load()) == [point]
        # Close explicitly instead of relying on `del` + GC, so the sqlite file
        # is released before it is reopened and before tmpdir is removed.
        persistence.close()

        persistence = CollectionPersistence(tmpdir)
        try:
            assert list(persistence.load()) == [point]
            persistence.delete(point.id)
            # Deleting an id that is no longer stored must not raise.
            persistence.delete(point.id)
            assert list(persistence.load()) == [], "Should not load anything"
        finally:
            persistence.close()

Qdrant

Learn more about the Qdrant vector search project and its ecosystem

Discover Qdrant

Similarity Learning

Explore practical problem-solving with Similarity Learning

Learn Similarity Learning

Community

Find people dealing with similar problems and get answers to your questions

Join Community