Handle semantic search better

Return a flat image array when semantic search is activated
Don't unload semantic search, when it's activated.
Extract class for search logic
Add some logging to the whole process
Sort by relevancy
pull/413/head
Niaz Faridani-Rad 2 years ago
parent 64a5a860ae
commit 8612a77eee

@ -1,3 +1,4 @@
import datetime
import operator
from functools import reduce
@ -5,6 +6,7 @@ from django.db.models import Q
from rest_framework import filters
from rest_framework.compat import distinct
import api.util as util
from api.image_similarity import search_similar_embedding
from api.semantic_search.semantic_search import semantic_search_instance
@ -23,13 +25,18 @@ class SemanticSearchFilter(filters.SearchFilter):
if request.user.semantic_search_topk > 0:
query = request.query_params.get("search")
start = datetime.datetime.now()
emb, magnitude = semantic_search_instance.calculate_query_embeddings(query)
semantic_search_instance.unload()
elapsed = (datetime.datetime.now() - start).total_seconds()
util.logger.info(
"finished calculating query embedding - took %.2f seconds" % (elapsed)
)
start = datetime.datetime.now()
image_hashes = search_similar_embedding(
request.user.id, emb, request.user.semantic_search_topk, threshold=27
)
elapsed = (datetime.datetime.now() - start).total_seconds()
util.logger.info("search similar embedding - took %.2f seconds" % (elapsed))
base = queryset
conditions = []
for search_term in search_terms:

@ -0,0 +1,47 @@
from django.db.models import Q
from rest_framework import viewsets
from rest_framework.response import Response
from api.filters import SemanticSearchFilter
from api.models import Photo
from api.views.pagination import HugeResultsSetPagination
from api.views.PhotosGroupedByDate import get_photos_ordered_by_date
from api.views.serializers_serpy import GroupedPhotosSerializer, PigPhotoSerilizer
class SearchListViewSet(viewsets.ModelViewSet):
serializer_class = GroupedPhotosSerializer
pagination_class = HugeResultsSetPagination
filter_backends = (SemanticSearchFilter,)
search_fields = [
"search_captions",
"search_location",
"faces__person__name",
"exif_timestamp",
"image_paths",
]
def get_queryset(self):
return Photo.visible.filter(Q(owner=self.request.user)).order_by(
"-exif_timestamp"
)
def retrieve(self, *args, **kwargs):
return super(SearchListViewSet, self).retrieve(*args, **kwargs)
def list(self, request):
if request.user.semantic_search_topk == 0:
queryset = self.filter_queryset(
Photo.visible.filter(Q(owner=self.request.user)).order_by(
"-exif_timestamp"
)
)
grouped_photos = get_photos_ordered_by_date(queryset)
serializer = GroupedPhotosSerializer(grouped_photos, many=True)
return Response({"results": serializer.data})
else:
queryset = self.filter_queryset(
Photo.visible.filter(Q(owner=self.request.user))
)
serializer = PigPhotoSerilizer(queryset, many=True)
return Response({"results": serializer.data})

@ -36,7 +36,6 @@ from api.autoalbum import delete_missing_photos
from api.directory_watcher import scan_faces, scan_photos
from api.drf_optimize import OptimizeRelatedModelViewSetMetaclass
from api.face_classify import cluster_faces, train_faces
from api.filters import SemanticSearchFilter
from api.models import (
AlbumAuto,
AlbumDate,
@ -65,7 +64,6 @@ from api.views.pagination import (
StandardResultsSetPagination,
TinyResultsSetPagination,
)
from api.views.PhotosGroupedByDate import get_photos_ordered_by_date
from api.views.serializers import (
AlbumAutoListSerializer,
AlbumUserEditSerializer,
@ -82,7 +80,6 @@ from api.views.serializers import (
SharedFromMePhotoThroughSerializer,
UserSerializer,
)
from api.views.serializers_serpy import GroupedPhotosSerializer
from api.views.serializers_serpy import (
PhotoSuperSimpleSerializer as PhotoSuperSimpleSerializerSerpy,
)
@ -211,35 +208,6 @@ class PhotoSimpleListViewSet(viewsets.ModelViewSet):
return super(PhotoSimpleListViewSet, self).list(*args, **kwargs)
class PhotoSuperSimpleSearchListViewSet(viewsets.ModelViewSet):
serializer_class = GroupedPhotosSerializer
pagination_class = HugeResultsSetPagination
filter_backends = (SemanticSearchFilter,)
search_fields = [
"search_captions",
"search_location",
"faces__person__name",
"exif_timestamp",
"image_paths",
]
def get_queryset(self):
return Photo.visible.filter(Q(owner=self.request.user)).order_by(
"-exif_timestamp"
)
def retrieve(self, *args, **kwargs):
return super(PhotoSuperSimpleSearchListViewSet, self).retrieve(*args, **kwargs)
def list(self, request):
queryset = self.filter_queryset(
Photo.visible.filter(Q(owner=self.request.user)).order_by("-exif_timestamp")
)
grouped_photos = get_photos_ordered_by_date(queryset)
serializer = GroupedPhotosSerializer(grouped_photos, many=True)
return Response({"results": serializer.data})
class PhotoSuperSimpleListViewSet(viewsets.ModelViewSet):
queryset = Photo.visible.order_by("-exif_timestamp")
@ -1269,8 +1237,6 @@ class DeleteMissingPhotosView(APIView):
return Response({"status": False})
class TrainFaceView(APIView):
def get(self, request, format=None):
try:

@ -34,13 +34,17 @@ class RetrievalIndex(object):
)
def search_similar(self, user_id, in_embedding, n=100, thres=27.0):
start = datetime.datetime.now()
dist, res_indices = self.indices[user_id].search(
np.array([in_embedding], dtype=np.float32), n
)
res = []
for distance, idx in zip(dist[0], res_indices[0]):
for distance, idx in sorted(zip(dist[0], res_indices[0]), reverse=True):
if distance >= thres:
res.append(self.image_hashes[user_id][idx])
logger.info("searched {} images for user {}".format(n, user_id))
elapsed = (datetime.datetime.now() - start).total_seconds()
logger.info(
"searched for %d images for user %d - took %.2f seconds"
% (n, user_id, elapsed)
)
return res

@ -26,7 +26,7 @@ from rest_framework_simplejwt.serializers import (
)
from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView
from api.views import album_auto, albums, photos, views
from api.views import album_auto, albums, photos, search, views
from nextcloud import views as nextcloud_views
schema_view = get_schema_view(
@ -173,9 +173,7 @@ router.register(
r"api/photos/favorites", photos.FavoritePhotoListViewset, basename="photo"
)
router.register(r"api/photos/hidden", photos.HiddenPhotoListViewset, basename="photo")
router.register(
r"api/photos/searchlist", views.PhotoSuperSimpleSearchListViewSet, basename="photo"
)
router.register(r"api/photos/searchlist", search.SearchListViewSet, basename="photo")
router.register(r"api/photos/public", photos.PublicPhotoListViewset, basename="photo")

Loading…
Cancel
Save