EP3542319A1 - Training neural networks using a clustering loss - Google Patents

Training neural networks using a clustering loss

Info

Publication number
EP3542319A1
Authority
EP
European Patent Office
Prior art keywords
assignment
clustering
neural network
training
ground truth
Legal status
Granted
Application number
EP17817337.3A
Other languages
German (de)
French (fr)
Other versions
EP3542319B1 (en)
Inventor
Hyun Oh Song
Current Assignee
Google LLC
Original Assignee
Google LLC
Application filed by Google LLC
Publication of EP3542319A1
Application granted
Publication of EP3542319B1
Status: Active

Classifications

    • G06N3/084 Backpropagation, e.g. using gradient descent
    • G06F18/23 Clustering techniques
    • G06N3/0464 Convolutional networks [CNN, ConvNet]
    • G06N3/047 Probabilistic or stochastic networks
    • G06N3/08 Learning methods
    • G06N3/0895 Weakly supervised learning, e.g. semi-supervised or self-supervised learning
    • G06N5/04 Inference or reasoning models

Definitions

  • This specification relates to training neural networks.
  • Neural networks are machine learning models that employ one or more layers of nonlinear units to predict an output for a received input.
  • Some neural networks include one or more hidden layers in addition to an output layer. The output of each hidden layer is used as input to the next layer in the network, i.e., the next hidden layer or the output layer.
  • Each layer of the network generates an output from a received input in accordance with current values of a respective set of parameters.
  • A recurrent neural network is a neural network that receives an input sequence and generates an output sequence from the input sequence.
  • In particular, a recurrent neural network can use some or all of the internal state of the network from a previous time step in computing an output at a current time step.
  • An example of a recurrent neural network is a long short-term memory (LSTM) neural network that includes one or more LSTM memory blocks. Each LSTM memory block can include one or more cells that each include an input gate, a forget gate, and an output gate that allow the cell to store previous states for the cell, e.g., for use in generating a current activation or to be provided to other components of the LSTM neural network.
  • This specification describes a system implemented as computer programs on one or more computers in one or more locations that trains a neural network that has network parameters and that is configured to receive an input data item and to process the input data item to generate an embedding of the input data item in accordance with the network parameters.
  • The neural network trained in the manner described herein may be used for image classification.
  • The trained neural network can generate embeddings that more accurately reflect the similarity between network inputs.
  • As a result, the embeddings generated by the trained neural network can be effectively used as features or representations of the network inputs for a variety of tasks, including feature based retrieval, clustering, near duplicate detection, verification, feature matching, domain adaptation, video based weakly supervised learning, and so on.
  • For example, a neural network trained in accordance with the training method described herein may be used for image classification. More specifically, in these examples, training the neural network in this manner allows the embeddings it generates to be used effectively in large-scale classification tasks, i.e., tasks where the number of classes is very large and the number of examples per class becomes scarce. In this setting, direct classification or regression methods become impractical because of the prohibitively large number of classes.
  • The described training techniques allow the embeddings generated by the trained neural network to be used to accurately classify network inputs into one of the classes, e.g., by representing each class with a respective medoid and determining the medoid that is closest to the embedding of the network input.
  • Moreover, many conventional approaches to training neural networks to generate embeddings require computationally intensive pre-processing of the training data before it can be used to train the neural network.
  • For example, many existing techniques require a separate data preparation stage in which the training data must first be prepared in pairs, i.e., with each pair including a positive and a negative example, in triplets, i.e., with each triplet including an anchor example, a positive example, and a negative example, or in n-pair tuple format before it can be used for training.
  • This procedure has high time and space costs because it often requires duplicating the training data and repeatedly accessing the disk to determine how to format the training examples.
  • By contrast, the training technique described in this specification requires little to no pre-processing before a batch of training items is used in training, reducing the computational cost and time required to train the neural network while still, as described above, training the network to generate effective embeddings.
  • FIG. 1 shows an example neural network training system.
  • FIG. 2 is a flow diagram of an example process for training a neural network using a clustering loss.
  • FIG. 3 is a flow diagram of an example process for determining an update to current values of the parameters of the neural network.
  • FIG. 1 shows an example neural network training system 100.
  • The neural network training system 100 is an example of a system implemented as computer programs on one or more computers in one or more locations, in which the systems, components, and techniques described below can be implemented.
  • The neural network training system 100 trains a neural network 110 on training data 140 to determine trained values of the parameters of the neural network 110, referred to in this specification as network parameters, from initial values of the network parameters.
  • The neural network 110 is configured to receive an input data item 102 and to process the input data item to generate an embedding 112 of the input data item 102 in accordance with the network parameters.
  • Generally, an embedding of a data item is an ordered collection of numeric values, e.g., a vector, that represents the data item.
  • In other words, each embedding is a point in a multi-dimensional embedding space.
  • The neural network 110 can be configured to receive any kind of digital data input and to generate an embedding from the input.
  • The input data items are also referred to in this specification as network inputs.
  • The neural network 110 can have any architecture that is appropriate for the type of network inputs it processes.
  • For example, the neural network 110 can be a convolutional neural network.
  • As a particular example, the embeddings can be the outputs of an intermediate layer of a convolutional neural network that has been pre-trained for image classification, e.g., the Inception network described in C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. Reed, D. Anguelov, D. Erhan, V. Vanhoucke, and A. Rabinovich, Going deeper with convolutions, in CVPR, 2015.
  • Once the network 110 has been trained, the embeddings it generates can be used for any of a variety of purposes.
  • For example, the system 100 can provide embeddings generated by the trained neural network as input to another system as features of the corresponding network inputs, e.g., for use in performing a machine learning task on the network inputs.
  • Example tasks may include feature based retrieval, clustering, near duplicate detection, verification, feature matching, domain adaptation, video based weakly supervised learning, and so on.
  • As another example, the system 100 can use an embedding generated by the trained neural network to classify the corresponding network input.
  • For example, the system can maintain data identifying a respective medoid, i.e., a respective representative point in the embedding space, for each of multiple possible classes.
  • The system 100 can then classify a network input as belonging to the class represented by the medoid that is closest to the embedding generated for the network input by the trained neural network.
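The nearest-medoid classification just described can be sketched in a few lines of Python. This is an illustrative sketch under stated assumptions, not the specification's implementation: distance is assumed to be Euclidean (the text only says "closest"), and the function name `classify_by_medoid`, the class names, and the medoid coordinates are hypothetical.

```python
import math
from typing import Dict, Sequence


def classify_by_medoid(embedding: Sequence[float],
                       class_medoids: Dict[str, Sequence[float]]) -> str:
    """Return the class whose medoid is closest to `embedding`.

    Assumes Euclidean distance; `class_medoids` maps each class name to its
    representative point (medoid) in the embedding space.
    """
    return min(class_medoids,
               key=lambda name: math.dist(embedding, class_medoids[name]))


# Hypothetical 2-D medoids for three classes.
medoids = {"cat": [0.0, 1.0], "dog": [1.0, 0.0], "car": [-1.0, -1.0]}
print(classify_by_medoid([0.9, 0.2], medoids))  # dog
```

Because only a nearest-neighbour search over the medoids is needed, adding a new class amounts to adding one more medoid, which is what makes this approach attractive when the number of classes is very large.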
  • The training data 140 that the system 100 uses to train the neural network 110 includes multiple batches of training inputs and a ground truth clustering assignment for the training inputs.
  • The ground truth clustering assignment assigns each training input to a respective cluster from a set of clusters.
  • For example, the set of clusters can include a respective cluster for each possible category or class into which a network input can be classified, and the ground truth assignment assigns each data item to the cluster for the category or class into which the data item should be classified.
  • The system 100 trains the neural network 110 on the training data 140 by optimizing a clustering objective 150.
  • In particular, the clustering objective 150 is an objective that, for a given batch of multiple training inputs, penalizes the neural network 110 for generating embeddings for which the oracle clustering score for the batch fails to exceed, for every possible clustering assignment other than the ground truth assignment for the batch, the clustering score for that possible clustering assignment by at least a structured margin between the possible clustering assignment and the ground truth assignment.
  • Each possible clustering assignment defines a clustering of the training examples in the batch by specifying a set of cluster medoids, i.e., a set of representative points in the embedding space, that includes one medoid for each cluster in the set of clusters. The clustering assignment then assigns each training item in the batch to the medoid that is closest to the embedding for the training item.
  • The clustering score (also referred to as a facility location score) for a given clustering assignment measures how close the embedding for each training item in the batch is to the medoid closest to that embedding. In particular, in some implementations, the facility location function F that generates the clustering scores satisfies:
    $$F(X, S; \theta) = -\sum_{i=1}^{|X|} \min_{j \in S} \lVert f(X_i; \theta) - f(X_j; \theta) \rVert,$$
    where X is the batch of training items, S is the set of medoids, θ denotes the network parameters, and f(X_i; θ) is the embedding generated by the neural network for training item X_i.
  • The oracle clustering score measures the quality of the clustering of the training examples in the batch given the ground truth clustering assignment and the network parameters, i.e., the quality of the clustering defined by the embeddings for the training inputs in the batch generated in accordance with the network parameters.
  • In some implementations, the oracle clustering function F̃ that generates the oracle clustering score is expressed as:
    $$\tilde{F}(X, y^*; \theta) = \sum_{k=1}^{|\mathcal{Y}|} \max_{j \in \{i : y^*[i] = k\}} F\big(X_{\{i : y^*[i] = k\}}, \{j\}; \theta\big),$$
    where y* is the ground truth clustering assignment, |𝒴| is the number of clusters, and X_{i : y*[i] = k} denotes the training items that the ground truth assignment assigns to cluster k. That is, for each cluster, the oracle selects the single medoid within that cluster that gives the highest clustering score.
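The two scores above can be computed directly for a small batch. The sketch below assumes Euclidean distance, restricts medoids to embeddings of items in the batch (each medoid is given by the index of a training item), and uses hypothetical names (`facility_location_score`, `oracle_score`) and toy embeddings; it is meant only to make the formulas concrete.

```python
import math
from typing import List, Sequence

Embedding = Sequence[float]


def facility_location_score(embeddings: List[Embedding],
                            medoids: Sequence[int]) -> float:
    """F(X, S): minus the total distance from each embedding to its closest medoid.

    `medoids` holds indices into `embeddings`, i.e. candidate medoids are assumed
    to be embeddings of items in the batch. Euclidean distance is assumed.
    """
    return -sum(min(math.dist(e, embeddings[j]) for j in medoids)
                for e in embeddings)


def oracle_score(embeddings: List[Embedding], labels: Sequence[int]) -> float:
    """Oracle score: for each ground-truth cluster, the best single-medoid score."""
    total = 0.0
    for k in set(labels):
        cluster = [e for e, y in zip(embeddings, labels) if y == k]
        # Try every member of cluster k as its medoid and keep the highest score.
        total += max(facility_location_score(cluster, [m])
                     for m in range(len(cluster)))
    return total


X = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]  # toy embeddings
y_star = [0, 0, 1, 1]                                  # ground truth clusters
print(facility_location_score(X, [0, 2]))  # -2.0
print(oracle_score(X, y_star))             # -2.0
```

Scores are non-positive and closer to zero for tighter clusterings; a medoid set that mixes the two toy clusters, such as `[0, 1]`, scores far below the oracle.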
  • The structured margin between a given possible clustering assignment and the ground truth clustering assignment measures the quality of the possible clustering assignment relative to the ground truth assignment.
  • In some implementations, the structured margin is based on a normalized mutual information (NMI) measure between the possible clustering assignment and the ground truth assignment.
  • In particular, the structured margin Δ between a clustering assignment y and the ground truth assignment y* is expressed as:
    $$\Delta(y, y^*) = 1 - \mathrm{NMI}(y, y^*),$$
    where NMI(y, y*) is the normalized mutual information between the two assignments and satisfies:
    $$\mathrm{NMI}(y_1, y_2) = \frac{\mathrm{MI}(y_1, y_2)}{\sqrt{H(y_1)\, H(y_2)}},$$
    where MI is the mutual information between the two assignments and H is the entropy of an assignment.
  • The entropy is based on the marginal probabilities of the clusters in an assignment, and the mutual information is based on both these marginal probabilities and the joint probabilities between one cluster in one assignment and another cluster in the other assignment.
  • The marginal probability of a given cluster under a given assignment can be estimated as the fraction of the training items that the assignment assigns to that cluster.
  • The joint probability used for computing the mutual information between a cluster i in a first assignment and a cluster j in a second assignment can be estimated as the fraction of the training items that are assigned to cluster i by the first assignment and to cluster j by the second assignment.
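The structured margin can be estimated from the empirical cluster frequencies as described above. In this sketch the helper names are hypothetical, natural logarithms are used (the base cancels in the NMI ratio), and the value returned when an assignment has zero entropy (where NMI is undefined) is a convention chosen for the sketch, not one taken from the specification.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def entropy(assignment: Sequence[Hashable]) -> float:
    """H(y), using the fraction of items in each cluster as its marginal probability."""
    m = len(assignment)
    return -sum((c / m) * math.log(c / m) for c in Counter(assignment).values())


def mutual_information(a: Sequence[Hashable], b: Sequence[Hashable]) -> float:
    """MI(a, b) from empirical marginal and joint cluster probabilities."""
    m = len(a)
    pa, pb, pab = Counter(a), Counter(b), Counter(zip(a, b))
    # (c/m) / ((pa/m) * (pb/m)) simplifies to c*m / (pa*pb).
    return sum((c / m) * math.log(c * m / (pa[i] * pb[j]))
               for (i, j), c in pab.items())


def nmi(a: Sequence[Hashable], b: Sequence[Hashable]) -> float:
    """NMI = MI / sqrt(H(a) * H(b))."""
    ha, hb = entropy(a), entropy(b)
    if ha == 0.0 or hb == 0.0:
        # NMI is undefined here; this convention is an assumption of the sketch.
        return 1.0 if ha == hb else 0.0
    return mutual_information(a, b) / math.sqrt(ha * hb)


def structured_margin(y: Sequence[Hashable], y_star: Sequence[Hashable]) -> float:
    """Delta(y, y*) = 1 - NMI(y, y*)."""
    return 1.0 - nmi(y, y_star)


print(structured_margin([1, 1, 0, 0], [0, 0, 1, 1]))  # ~0.0: same partition
print(structured_margin([0, 1, 0, 1], [0, 0, 1, 1]))  # 1.0: independent partitions
```

Because NMI depends only on how items are grouped, not on the cluster labels, relabeling the clusters leaves the margin unchanged.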
  • The clustering loss function for a batch of training inputs X and a ground truth clustering assignment y* for the batch can then satisfy:
    $$\ell(X, y^*) = \Big[\max_{S \subset V,\ |S| = |\mathcal{Y}|} \big\{F(X, S; \theta) + \gamma\, \Delta(g(S), y^*)\big\} - \tilde{F}(X, y^*; \theta)\Big]_+,$$
    where V is the set of candidate medoids, γ is a positive constant value, g(S) is a function that assigns each training item to the cluster that is represented by the medoid in the assignment S closest to the embedding for the training item, and [a]_+ equals zero if a is less than or equal to zero and a if a is greater than zero.
  • The quantity F(X, S; θ) + γΔ(g(S), y*) for a given clustering assignment will be referred to in this specification as the augmented clustering score for the clustering assignment. Training the neural network on this objective is described in more detail below with reference to FIGS. 2 and 3.
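For very small batches the maximization inside the loss can be done by exhaustive search, which makes the loss easy to check by hand. The sketch below is not the inference procedure of the specification (FIG. 3 describes a greedy search followed by refinement): it enumerates every set of one medoid per cluster drawn from the batch, assumes Euclidean distance, does not exclude the ground truth assignment from the maximization, and uses hypothetical helper names.

```python
import itertools
import math
from collections import Counter
from typing import List, Sequence

Embedding = Sequence[float]


def score(emb: List[Embedding], medoids: Sequence[int]) -> float:
    """Facility location score F(X, S) with Euclidean distances."""
    return -sum(min(math.dist(e, emb[j]) for j in medoids) for e in emb)


def oracle(emb: List[Embedding], y_star: Sequence[int]) -> float:
    """Oracle score: best single medoid inside each ground-truth cluster."""
    total = 0.0
    for k in set(y_star):
        cluster = [e for e, y in zip(emb, y_star) if y == k]
        total += max(score(cluster, [m]) for m in range(len(cluster)))
    return total


def assign(emb: List[Embedding], medoids: Sequence[int]) -> List[int]:
    """g(S): label each item with the position of its closest medoid in S."""
    return [min(range(len(medoids)), key=lambda c: math.dist(e, emb[medoids[c]]))
            for e in emb]


def nmi(a: Sequence[int], b: Sequence[int]) -> float:
    m = len(a)
    pa, pb, pab = Counter(a), Counter(b), Counter(zip(a, b))
    ha = -sum((c / m) * math.log(c / m) for c in pa.values())
    hb = -sum((c / m) * math.log(c / m) for c in pb.values())
    if ha == 0.0 or hb == 0.0:
        return 1.0 if ha == hb else 0.0  # convention for degenerate cases (assumed)
    mi = sum((c / m) * math.log(c * m / (pa[i] * pb[j])) for (i, j), c in pab.items())
    return mi / math.sqrt(ha * hb)


def clustering_loss(emb: List[Embedding], y_star: Sequence[int],
                    gamma: float = 1.0) -> float:
    """Hinge of (max augmented score over medoid sets) minus the oracle score.

    The max is found by exhaustive search over all medoid sets with one medoid
    per cluster, which is only feasible for tiny batches.
    """
    n_clusters = len(set(y_star))
    best = max(score(emb, s) + gamma * (1.0 - nmi(assign(emb, s), y_star))
               for s in itertools.combinations(range(len(emb)), n_clusters))
    return max(0.0, best - oracle(emb, y_star))


X = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]
print(clustering_loss(X, [0, 0, 1, 1]))  # ~0.0: embeddings already match the labels
print(clustering_loss(X, [0, 1, 0, 1]))  # about 13.142: labels cut across the geometry
```

The loss is zero when no competing clustering beats the oracle by more than its margin, and it grows as the embeddings disagree with the ground truth grouping.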
  • Once the neural network has been trained, the system 100 provides data specifying the trained neural network for use in processing new network inputs. That is, the system 100 can output, e.g., by outputting to a user device or by storing in a memory accessible to the system 100, the trained values of the network parameters for later use in processing inputs using the trained neural network. Alternatively or in addition to outputting the trained neural network data, the system 100 can instantiate an instance of the neural network having the trained values of the network parameters, receive inputs to be processed, e.g., through an application programming interface (API) offered by the system, use the trained neural network to process the received inputs to generate embeddings, and then provide the generated embeddings in response to the received inputs.
  • FIG. 2 is a flow diagram of an example process 200 for training a neural network on a batch of training data.
  • For convenience, the process 200 will be described as being performed by a system of one or more computers located in one or more locations.
  • For example, a neural network training system, e.g., the neural network training system 100 of FIG. 1, appropriately programmed, can perform the process 200.
  • The system can perform the process 200 multiple times for multiple different batches of training items to determine trained values of the network parameters from initial values of the network parameters.
  • The system obtains a batch of training items and a ground truth assignment of the training items in the batch into a plurality of clusters (step 202).
  • As described above, the ground truth assignment assigns each training item in the batch to a respective cluster from the set of clusters.
  • The system processes each training item in the batch using the neural network and in accordance with current values of the network parameters to generate a respective embedding for each of the training items (step 204).
  • The system determines an oracle clustering score for the ground truth assignment based on the embeddings for the training items in the batch (step 206).
  • As described above, the oracle clustering score measures the quality of the clustering given the ground truth clustering assignment and the network parameters, i.e., the quality of the clustering defined by the embeddings generated in accordance with the current values of the network parameters.
  • The system adjusts the current values of the network parameters by performing an iteration of a neural network training procedure to optimize, i.e., minimize, the clustering objective using the oracle clustering score (step 208).
  • Generally, the training procedure determines an update to the current values of the parameters from a gradient of the clustering objective with respect to the parameters and then applies, e.g., adds, the update to the current values to determine updated values of the parameters.
  • For example, the training procedure can be stochastic gradient descent, in which case the system can multiply the gradient by the negative of a learning rate to determine the update and then add the update to the current values of the network parameters. Determining the gradient of the clustering objective is described in more detail below with reference to FIG. 3.
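As a concrete illustration of the update step, here is plain gradient descent on a flat list of parameters. It assumes the update is the gradient scaled by the negative of the learning rate, so that adding it lowers the objective; the function name `sgd_step` and the toy quadratic objective are hypothetical and stand in for the clustering objective.

```python
from typing import List


def sgd_step(params: List[float], grads: List[float],
             learning_rate: float) -> List[float]:
    """One gradient descent step on a flat list of parameters.

    The update is the gradient scaled by the negative learning rate, so adding it
    to the current values moves the parameters toward a lower objective value.
    """
    update = [-learning_rate * g for g in grads]
    return [p + u for p, u in zip(params, update)]


# Toy usage: minimise (w - 3)^2, whose gradient is 2 * (w - 3).
w = [0.0]
for _ in range(100):
    w = sgd_step(w, [2.0 * (w[0] - 3.0)], learning_rate=0.1)
print(round(w[0], 4))  # 3.0
```

In the process 200, one such step is taken per batch, with the gradient computed from the clustering objective as described for FIG. 3.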
  • FIG. 3 is a flow diagram of an example process 300 for determining an update to current values of the network parameters.
  • For convenience, the process 300 will be described as being performed by a system of one or more computers located in one or more locations.
  • For example, a neural network training system, e.g., the neural network training system 100 of FIG. 1, appropriately programmed, can perform the process 300.
  • The system can perform the process 300 during the training of the neural network on a batch of training inputs to determine an update to the current values of the network parameters for the batch.
  • The system can then apply, i.e., add, the updates determined for the inputs in the batch to generate updated values of the network parameters.
  • The system determines the possible clustering assignment other than the ground truth assignment that has the highest augmented clustering score (step 302).
  • As described above, the augmented clustering score is the clustering score for the possible clustering assignment plus the structured margin between the possible clustering assignment and the ground truth assignment, weighted by the positive constant γ.
  • In particular, the system first determines an initial best possible clustering assignment using an iterative loss augmented inference technique.
  • At each iteration of the loss augmented inference technique, the system adds a new medoid to the clustering assignment. That is, the system starts with a clustering assignment that has zero medoids, i.e., that does not assign any embeddings to any clusters, and continues adding medoids until the number of medoids in the clustering assignment equals the number of clusters in the ground truth assignment.
  • Specifically, at each iteration, the system adds to the current clustering assignment the medoid that most increases the augmented clustering score for the clustering assignment.
  • The system then modifies the initial best possible clustering assignment using a loss augmented refinement technique to determine the highest scoring possible clustering assignment.
  • In particular, the system performs multiple iterations of the refinement technique to determine the highest scoring possible clustering assignment from the initial best possible clustering assignment.
  • At each iteration and for each cluster, the system determines whether modifying the current best possible clustering assignment by performing a pairwise exchange of the current medoid for the cluster with an alternative point in the same cluster, according to the current best possible clustering assignment, would increase the augmented clustering score for the clustering assignment. If so, the system swaps the current medoid for the alternative point to update the clustering assignment.
  • The number of iterations of the loss augmented refinement technique may be fixed, e.g., to three, five, or seven iterations, or the system may continue performing the technique until none of the pairwise swaps would improve the augmented clustering score.
  • In some implementations, the system considers as candidate medoids only the embeddings for the training items in the batch. In some other implementations, the system also considers points in the embedding space that are not embeddings of training items in the batch, e.g., the entire space of possible points or a predetermined discrete subset of possible points.
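The greedy loss augmented inference and the pairwise-exchange refinement of step 302 can be sketched as follows. This is a simplified illustration, not the specification's implementation: candidate medoids are restricted to embeddings of items in the batch, distance is Euclidean, the ground truth assignment is not explicitly excluded, refinement stops after a hypothetical `max_iters=5` iterations or as soon as no swap improves the score, and all function names are hypothetical.

```python
import math
from collections import Counter
from typing import List, Sequence

Embedding = Sequence[float]


def score(emb: List[Embedding], medoids: Sequence[int]) -> float:
    """Facility location score F(X, S) with Euclidean distances."""
    return -sum(min(math.dist(e, emb[j]) for j in medoids) for e in emb)


def assign(emb: List[Embedding], medoids: Sequence[int]) -> List[int]:
    """g(S): label each item with the position of its closest medoid."""
    return [min(range(len(medoids)), key=lambda c: math.dist(e, emb[medoids[c]]))
            for e in emb]


def nmi(a: Sequence[int], b: Sequence[int]) -> float:
    m = len(a)
    pa, pb, pab = Counter(a), Counter(b), Counter(zip(a, b))
    ha = -sum((c / m) * math.log(c / m) for c in pa.values())
    hb = -sum((c / m) * math.log(c / m) for c in pb.values())
    if ha == 0.0 or hb == 0.0:
        return 1.0 if ha == hb else 0.0  # convention for degenerate cases (assumed)
    mi = sum((c / m) * math.log(c * m / (pa[i] * pb[j])) for (i, j), c in pab.items())
    return mi / math.sqrt(ha * hb)


def augmented(emb, medoids, y_star, gamma):
    """Augmented clustering score F(X, S) + gamma * Delta(g(S), y*)."""
    return score(emb, medoids) + gamma * (1.0 - nmi(assign(emb, medoids), y_star))


def greedy_inference(emb, y_star, gamma):
    """Start from zero medoids; repeatedly add the one that most increases the score."""
    medoids: List[int] = []
    while len(medoids) < len(set(y_star)):
        candidates = [j for j in range(len(emb)) if j not in medoids]
        medoids.append(max(candidates,
                           key=lambda j: augmented(emb, medoids + [j], y_star, gamma)))
    return medoids


def refine(emb, medoids, y_star, gamma, max_iters=5):
    """Pairwise exchange: swap a medoid with another point in its cluster if that helps."""
    medoids = list(medoids)
    current = augmented(emb, medoids, y_star, gamma)
    for _ in range(max_iters):
        improved = False
        for c in range(len(medoids)):
            members = [i for i, lab in enumerate(assign(emb, medoids)) if lab == c]
            for alt in members:
                if alt in medoids:
                    continue
                trial = medoids[:c] + [alt] + medoids[c + 1:]
                trial_score = augmented(emb, trial, y_star, gamma)
                if trial_score > current:
                    medoids, current, improved = trial, trial_score, True
        if not improved:  # no swap improved the score: stop early
            break
    return medoids


X = [[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]]
y_star = [0, 0, 1, 1]
s_hat = refine(X, greedy_inference(X, y_star, 1.0), y_star, 1.0)
print(round(augmented(X, s_hat, y_star, 1.0), 6))  # -2.0
```

The greedy pass costs one score evaluation per candidate per added medoid, and the refinement only ever accepts strict improvements, so it cannot make the initial assignment worse.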
  • The system determines a gradient of the clustering objective with respect to the network parameters using the highest scoring clustering assignment (step 304).
  • In particular, the system can determine the gradient as the difference between the gradient of the augmented clustering scoring function term of the objective function with respect to the network parameters and the gradient of the oracle scoring function term of the objective function with respect to the network parameters. That is, the total gradient is a first gradient term, i.e., the gradient of the augmented clustering scoring function, minus a second gradient term, i.e., the gradient of the oracle scoring function. Because the structured margin term does not depend on θ once the highest scoring assignment is fixed, the first gradient term is the gradient of F alone and satisfies:
    $$\nabla_\theta F(X, \hat{S}; \theta) = -\sum_{i=1}^{|X|} \frac{f(X_i; \theta) - f(X_{j^*(i)}; \theta)}{\lVert f(X_i; \theta) - f(X_{j^*(i)}; \theta) \rVert} \cdot \nabla_\theta \big(f(X_i; \theta) - f(X_{j^*(i)}; \theta)\big),$$
    where Ŝ is the set of medoids of the highest scoring clustering assignment and j*(i) indexes the medoid in Ŝ that is closest to the embedding f(X_i; θ).
  • Here, ∇_θ(f(X_i; θ) − f(X_{j*(i)}; θ)) is the gradient with respect to the network parameters of f(X_i; θ) − f(X_{j*(i)}; θ).
  • The second gradient term satisfies:
    $$\nabla_\theta \tilde{F}(X, y^*; \theta) = \sum_{k=1}^{|\mathcal{Y}|} \nabla_\theta F\big(X_{\{i : y^*[i] = k\}}, \{j^*(k)\}; \theta\big),$$
    where F(X_{i : y*[i] = k}, {j*(k)}; θ) is the clustering score for the training inputs that are assigned to cluster k by the ground truth clustering assignment, relative to the medoid j*(k) for cluster k, i.e., the medoid that attains the maximum for cluster k in the oracle clustering score.
  • The system can compute the gradients in the first and second gradient terms with respect to all of the parameters of the neural network using conventional neural network training techniques, i.e., by backpropagating gradients through the neural network.
  • In some implementations, the system computes the gradient as above only if the loss for the batch, i.e., the value of the clustering loss function ℓ(X, y*) for the batch, is greater than zero. If the loss is less than or equal to zero, the system sets the gradients to zero and does not update the current values of the network parameters.
  • The system determines an update to the current values of the network parameters from the gradient of the clustering objective (step 306). For example, the system can determine the update by applying a learning rate to the gradient.
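The gradient computation of step 304, together with the zero-gradient rule for a non-positive loss, can be illustrated by treating the embeddings themselves as the trainable parameters (an identity network). In a real network, the per-embedding gradients produced here would be backpropagated through the network to obtain gradients for the network parameters. Euclidean distance, the zero subgradient used when an item coincides with its medoid, and all function names are assumptions of this sketch.

```python
import math
from typing import List, Sequence

Vec = List[float]


def grad_score(emb: List[Vec], medoids: Sequence[int]) -> List[Vec]:
    """Gradient of F(X, S) = -sum_i ||e_i - e_j*(i)|| w.r.t. every embedding."""
    grads = [[0.0] * len(e) for e in emb]
    for i, e in enumerate(emb):
        j = min(medoids, key=lambda m: math.dist(e, emb[m]))  # j*(i): closest medoid
        d = math.dist(e, emb[j])
        if d == 0.0:
            continue  # item coincides with its medoid: use the zero subgradient
        for t in range(len(e)):
            u = (e[t] - emb[j][t]) / d  # component of the unit vector (e_i - e_j)/||.||
            grads[i][t] -= u  # derivative of -||e_i - e_j|| w.r.t. e_i
            grads[j][t] += u  # derivative of -||e_i - e_j|| w.r.t. e_j
    return grads


def grad_oracle(emb: List[Vec], y_star: Sequence[int]) -> List[Vec]:
    """Gradient of the oracle score: per cluster, the gradient at its best medoid j*(k)."""
    grads = [[0.0] * len(e) for e in emb]
    for k in set(y_star):
        members = [i for i, y in enumerate(y_star) if y == k]
        cluster = [emb[i] for i in members]
        # j*(k) maximises the single-medoid score, i.e. minimises total distance.
        best = min(range(len(cluster)),
                   key=lambda m: sum(math.dist(c, cluster[m]) for c in cluster))
        for local, row in enumerate(grad_score(cluster, [best])):
            target = grads[members[local]]
            for t in range(len(row)):
                target[t] += row[t]
    return grads


def loss_grad(emb, y_star, s_hat, loss_value):
    """I[loss > 0] * (gradient of augmented score at s_hat - gradient of oracle score)."""
    if loss_value <= 0.0:
        return [[0.0] * len(e) for e in emb]  # zero gradient: no parameter update
    ga, go = grad_score(emb, s_hat), grad_oracle(emb, y_star)
    return [[a - b for a, b in zip(ra, ro)] for ra, ro in zip(ga, go)]


print(grad_score([[0.0, 0.0], [3.0, 4.0]], [0]))  # approximately [[0.6, 0.8], [-0.6, -0.8]]
```

When the highest scoring assignment uses the same medoids as the oracle, the two gradient terms cancel, which matches the intuition that a correctly clustered batch should not move the parameters.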
  • Embodiments of the subject matter and the functional operations described in this specification can be implemented in digital electronic circuitry, in tangibly-embodied computer software or firmware, in computer hardware, including the structures disclosed in this specification and their structural equivalents, or in combinations of one or more of them.
  • Embodiments of the subject matter described in this specification can be implemented as one or more computer programs, i.e., one or more modules of computer program instructions encoded on a tangible non-transitory storage medium for execution by, or to control the operation of, data processing apparatus.
  • The computer storage medium can be a machine-readable storage device, a machine-readable storage substrate, a random or serial access memory device, or a combination of one or more of them.
  • Alternatively or in addition, the program instructions can be encoded on an artificially generated propagated signal, e.g., a machine-generated electrical, optical, or electromagnetic signal, that is generated to encode information for transmission to suitable receiver apparatus for execution by a data processing apparatus.
  • The term “data processing apparatus” refers to data processing hardware and encompasses all kinds of apparatus, devices, and machines for processing data, including by way of example a programmable processor, a computer, or multiple processors or computers.
  • The apparatus can also be, or further include, special purpose logic circuitry, e.g., an FPGA (field programmable gate array) or an ASIC (application specific integrated circuit).
  • The apparatus can optionally include, in addition to hardware, code that creates an execution environment for computer programs, e.g., code that constitutes processor firmware, a protocol stack, a database management system, an operating system, or a combination of one or more of them.
  • A computer program, which may also be referred to or described as a program, software, a software application, an app, a module, a software module, a script, or code, can be written in any form of programming language, including compiled or interpreted languages, or declarative or procedural languages; and it can be deployed in any form, including as a stand-alone program or as a module, component, subroutine, or other unit suitable for use in a computing environment.
  • A program may, but need not, correspond to a file in a file system.
  • A program can be stored in a portion of a file that holds other programs or data, e.g., one or more scripts stored in a markup language document, in a single file dedicated to the program in question, or in multiple coordinated files, e.g., files that store one or more modules, sub-programs, or portions of code.
  • A computer program can be deployed to be executed on one computer or on multiple computers that are located at one site or distributed across multiple sites and interconnected by a data communication network.
  • In this specification, the term “database” is used broadly to refer to any collection of data: the data does not need to be structured in any particular way, or structured at all, and it can be stored on storage devices in one or more locations.
  • Thus, for example, the index database can include multiple collections of data, each of which may be organized and accessed differently.
  • Similarly, in this specification the term “engine” is used broadly to refer to a software-based system, subsystem, or process that is programmed to perform one or more specific functions.
  • Generally, an engine will be implemented as one or more software modules or components, installed on one or more computers in one or more locations. In some cases, one or more computers will be dedicated to a particular engine; in other cases, multiple engines can be installed and running on the same computer or computers.
  • The processes and logic flows described in this specification can be performed by one or more programmable computers executing one or more computer programs to perform functions by operating on input data and generating output.
  • The processes and logic flows can also be performed by special purpose logic circuitry, e.g., an FPGA or an ASIC, or by a combination of special purpose logic circuitry and one or more programmed computers.
  • Computers suitable for the execution of a computer program can be based on general or special purpose microprocessors or both, or any other kind of central processing unit.
  • Generally, a central processing unit will receive instructions and data from a read only memory or a random access memory or both.
  • The essential elements of a computer are a central processing unit for performing or executing instructions and one or more memory devices for storing instructions and data.
  • The central processing unit and the memory can be supplemented by, or incorporated in, special purpose logic circuitry.
  • Generally, a computer will also include, or be operatively coupled to receive data from or transfer data to, or both, one or more mass storage devices for storing data, e.g., magnetic, magneto-optical, or optical disks. However, a computer need not have such devices.
  • Moreover, a computer can be embedded in another device, e.g., a mobile telephone, a personal digital assistant (PDA), a mobile audio or video player, a game console, a Global Positioning System (GPS) receiver, or a portable storage device, e.g., a universal serial bus (USB) flash drive, to name just a few.
  • Computer-readable media suitable for storing computer program instructions and data include all forms of non-volatile memory, media and memory devices, including by way of example semiconductor memory devices, e.g., EPROM, EEPROM, and flash memory devices; magnetic disks, e.g., internal hard disks or removable disks; magneto-optical disks; and CD-ROM and DVD-ROM disks.
  • To provide for interaction with a user, embodiments of the subject matter described in this specification can be implemented on a computer having a display device, e.g., a CRT (cathode ray tube) or LCD (liquid crystal display) monitor, for displaying information to the user, and a keyboard and a pointing device, e.g., a mouse or a trackball, by which the user can provide input to the computer.
  • Other kinds of devices can be used to provide for interaction with a user as well; for example, feedback provided to the user can be any form of sensory feedback, e.g., visual feedback, auditory feedback, or tactile feedback; and input from the user can be received in any form, including acoustic, speech, or tactile input.
  • In addition, a computer can interact with a user by sending documents to and receiving documents from a device that is used by the user; for example, by sending web pages to a web browser on a user's device in response to requests received from the web browser.
  • Also, a computer can interact with a user by sending text messages or other forms of message to a personal device, e.g., a smartphone that is running a messaging application, and receiving responsive messages from the user in return.
  • Data processing apparatus for implementing machine learning models can also include, for example, special-purpose hardware accelerator units for processing common and compute-intensive parts of machine learning training or production, i.e., inference, workloads.
  • Machine learning models can be implemented and deployed using a machine learning framework, e.g., a TensorFlow framework, a Microsoft Cognitive Toolkit framework, an Apache Singa framework, or an Apache MXNet framework.
  • Embodiments of the subject matter described in this specification can be implemented in a computing system that includes a back end component, e.g., as a data server, or that includes a middleware component, e.g., an application server, or that includes a front end component, e.g., a client computer having a graphical user interface, a web browser, or an app through which a user can interact with an implementation of the subject matter described in this specification, or any combination of one or more such back end, middleware, or front end components.
  • The components of the system can be interconnected by any form or medium of digital data communication, e.g., a communication network.
  • Examples of communication networks include a local area network (LAN) and a wide area network (WAN), e.g., the Internet.
  • The computing system can include clients and servers.
  • A client and server are generally remote from each other and typically interact through a communication network. The relationship of client and server arises by virtue of computer programs running on the respective computers and having a client-server relationship to each other.
  • A server transmits data, e.g., an HTML page, to a user device, e.g., for purposes of displaying data to and receiving user input from a user interacting with the device, which acts as a client.
  • Data generated at the user device, e.g., a result of the user interaction, can be received at the server from the device.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Physics & Mathematics (AREA)
  • Data Mining & Analysis (AREA)
  • Artificial Intelligence (AREA)
  • General Engineering & Computer Science (AREA)
  • Evolutionary Computation (AREA)
  • General Physics & Mathematics (AREA)
  • Software Systems (AREA)
  • Computational Linguistics (AREA)
  • Mathematical Physics (AREA)
  • Computing Systems (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biophysics (AREA)
  • Molecular Biology (AREA)
  • General Health & Medical Sciences (AREA)
  • Biomedical Technology (AREA)
  • Health & Medical Sciences (AREA)
  • Probability & Statistics with Applications (AREA)
  • Bioinformatics & Cheminformatics (AREA)
  • Bioinformatics & Computational Biology (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Evolutionary Biology (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)
  • Image Analysis (AREA)

Abstract

Methods, systems, and apparatus, including computer programs encoded on computer storage media, for training a neural network. One of the methods includes obtaining a batch of training items and a ground truth assignment; processing the training items in the batch using the neural network to generate respective embeddings for each of the training items; and adjusting the current values of the network parameters by performing an iteration of a neural network training procedure to optimize an objective function that penalizes the neural network for generating embeddings that do not result in, for each possible clustering assignment other than the ground truth assignment, an oracle clustering score for the ground truth assignment being higher than a clustering score for the possible clustering assignment by at least a structured margin between the possible clustering assignment and the ground truth assignment.

Description

TRAINING NEURAL NETWORKS USING A CLUSTERING LOSS
BACKGROUND
This specification relates to training neural networks.
Neural networks are machine learning models that employ one or more layers of nonlinear units to predict an output for a received input. Some neural networks include one or more hidden layers in addition to an output layer. The output of each hidden layer is used as input to the next layer in the network, i.e., the next hidden layer or the output layer. Each layer of the network generates an output from a received input in accordance with current values of a respective set of parameters.
Some neural networks are recurrent neural networks. A recurrent neural network is a neural network that receives an input sequence and generates an output sequence from the input sequence. In particular, a recurrent neural network can use some or all of the internal state of the network from a previous time step in computing an output at a current time step. An example of a recurrent neural network is a long short-term memory (LSTM) neural network that includes one or more LSTM memory blocks. Each LSTM memory block can include one or more cells that each include an input gate, a forget gate, and an output gate that allow the cell to store previous states for the cell, e.g., for use in generating a current activation or to be provided to other components of the LSTM neural network.
SUMMARY
This specification describes a system implemented as computer programs on one or more computers in one or more locations that trains a neural network that has network parameters and that is configured to receive an input data item and to process the input data item to generate an embedding of the input data item in accordance with the network parameters. In some specific, non-limiting examples, the neural network trained in the manner described herein may be used for image classification.
Particular embodiments of the subject matter described in this specification can be implemented so as to realize one or more of the following advantages. By training a neural network as described in this specification, i.e., by training the neural network to optimize the described objective, the trained neural network can generate embeddings that more accurately reflect the similarity between network inputs. In particular, by training the neural network as described in this specification, the embeddings generated by the trained neural network can be effectively used as features or representations of the network inputs for a variety of tasks, including feature based retrieval, clustering, near duplicate detection, verification, feature matching, domain adaptation, video based weakly supervised learning, and so on.
In some specific examples, a neural network trained in accordance with the training method described herein may be used for image classification. More specifically, in these examples, by training the neural network in this manner, the embeddings generated by the neural network can be effectively used in large scale classification tasks, i.e., tasks where the number of classes is very large and the number of examples per class becomes scarce. In this setting, any direct classification or regression methods become impractical due to the prohibitively large number of classes. The described training techniques, however, allow the embeddings generated by the trained neural network to be used to accurately classify network inputs into one of the classes, e.g., by representing each class with a respective medoid and determining the medoid that is closest to the embedding of the network input.
Additionally, many conventional approaches to training neural networks to generate embeddings require computationally intensive pre-processing of the training data before it can be used to train the neural network. For example, many existing techniques require a separate data preparation stage in which the training data has to be first prepared in pairs, i.e., with each pair including a positive and a negative example, in triplets, i.e., with each triplet including an anchor example, a positive example, and a negative example, or in n-pair tuple format, before the training data can be used for training. This procedure is expensive in both time and space because it often requires duplicating the training data and repeatedly accessing the disk to determine how to format the training examples. By contrast, the training technique described in this specification requires little to no pre-processing before a batch of training items is used in training, reducing the computational cost and time required for training the neural network while still, as described above, training the network to effectively generate embeddings.
The details of one or more embodiments of the subject matter described in this specification are set forth in the accompanying drawings and the description below.
Other features, aspects, and advantages of the subject matter will become apparent from the description, the drawings, and the claims.
BRIEF DESCRIPTION OF THE DRAWINGS
FIG. 1 shows an example neural network training system.
FIG. 2 is a flow diagram of an example process for training a neural network using a clustering loss.
FIG. 3 is a flow diagram of an example process for determining an update to current values of the parameters of the neural network.
Like reference numbers and designations in the various drawings indicate like elements.
DETAILED DESCRIPTION
FIG. 1 shows an example neural network training system 100. The neural network training system 100 is an example of a system implemented as computer programs on one or more computers in one or more locations, in which the systems, components, and techniques described below can be implemented.
The neural network training system 100 is a system that trains a neural network 110 on training data 140 to determine trained values of the parameters of the neural network 110, referred to in this specification as network parameters, from initial values of the network parameters.
The neural network 110 is a neural network that is configured to receive an input data item 102 and to process the input data item to generate an embedding 112 of the input data item 102 in accordance with the network parameters. Generally, an embedding of a data item is an ordered collection of numeric values, e.g., a vector, that represents the data item. In other words, each embedding is a point in a multi-dimensional embedding space. Once trained, the positions of embeddings in the multi-dimensional space generated by the neural network 110 can reflect similarities between the data items that the embeddings represent.
The neural network 110 can be configured to receive as input any kind of digital data input and to generate an embedding from the input. For example, the input data items, also referred to as network inputs, can be images, portions of documents, text sequences, audio data, and so on.
The neural network 110 can have any architecture that is appropriate for the type of network inputs processed by the neural network 110. For example, when the network inputs are images, the neural network 110 can be a convolutional neural network. As a specific example, the embeddings can be the outputs of an intermediate layer of a convolutional neural network that has been pre-trained for image classification, e.g., the Inception network described in C. Szegedy, W. Liu, Y. Jia, P. Sermanet, S. Reed, D. Anguelov, D. Erhan, V. Vanhoucke, and A. Rabinovich, Going deeper with convolutions, in CVPR, 2015.
Once trained, the embeddings generated by the network 110 can be used for any of a variety of purposes.
For example, the system 100 can provide embeddings generated by the trained neural network as input to another system as features of the corresponding network inputs, e.g., for use in performing a machine learning task on the network inputs.
Example tasks may include feature based retrieval, clustering, near duplicate detection, verification, feature matching, domain adaptation, video based weakly supervised learning, and so on.
As another example, the system 100 can use an embedding generated by the trained neural network to classify the corresponding network input. In particular, the system can maintain data identifying a respective medoid, i.e., a respective representative point in the embedding space, for each of a set of multiple possible classes. The system 100 can then classify a network input as belonging to the class represented by the medoid that is closest to the embedding generated for the network input by the trained neural network.
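A minimal sketch of the nearest-medoid classification described above, assuming Euclidean distance in the embedding space and medoids stored in a dictionary keyed by class label. The function name `classify_by_medoid` and the toy labels are hypothetical and chosen only for illustration.

```python
import math


def classify_by_medoid(embedding, medoids):
    """Return the label whose medoid is closest to `embedding`.

    Assumptions (illustrative only): `medoids` maps each class label to a
    representative point in the embedding space, and closeness is measured
    with the Euclidean distance.
    """
    return min(medoids, key=lambda label: math.dist(embedding, medoids[label]))


# Toy 2-D embedding space with three hypothetical classes.
medoids = {"cat": (0.0, 1.0), "dog": (1.0, 0.0), "car": (5.0, 5.0)}
print(classify_by_medoid((0.9, 0.2), medoids))  # dog
```

Each classification costs one distance evaluation per class medoid, which is what keeps this approach usable when the number of classes is very large.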
The training data 140 that is used by the system 100 to train the neural network 110 includes multiple batches of training inputs and a ground truth clustering assignment for the training inputs. The ground truth clustering assignment assigns each training input into a respective cluster of a set of clusters. For example, in the classification context, the set of clusters can include a respective cluster for each possible category or class into which a network input can be classified and the ground truth assignment assigns each data item to the cluster for the category or class into which the data item should be classified.
The system 100 trains the neural network 110 on the training data 140 by optimizing a clustering objective 150. In particular, the clustering objective 150 is an objective that, for a given batch of multiple training inputs, penalizes the neural network 110 for generating embeddings that do not result in, for each possible clustering assignment other than the ground truth assignment for the batch, an oracle clustering score for the batch being higher than a clustering score for the possible clustering assignment by at least a structured margin between the possible clustering assignment and the ground truth assignment. Each possible clustering assignment defines a clustering of the training examples in the batch by specifying a set of cluster medoids, i.e., a set of representative points in the embedding space, that includes one medoid for each cluster in the set of clusters. The clustering assignment then assigns each training item in the batch to the medoid that is closest to the embedding for the training item.
The clustering score (also referred to as a facility location score) for a given clustering assignment measures how close the embedding of each training item in the batch is to its nearest medoid. In particular, in some implementations, the facility location function F that generates the clustering scores satisfies:
$$F(X, S; \Theta) = -\sum_{i \in |X|} \min_{j \in S} \lVert f(X_i; \Theta) - f(X_j; \Theta) \rVert,$$
where |X| is the total number of training inputs X in the batch, the sum is a sum over all of the training inputs in the batch, S is the set of medoids in the given clustering assignment, and, for the i-th training input $X_i$, $f(X_i; \Theta)$ is the embedding of the training input generated in accordance with the network parameters $\Theta$ and $\min_{j \in S} \lVert f(X_i; \Theta) - f(X_j; \Theta) \rVert$ is the distance from the closest medoid in the set of medoids to the embedding of the training input.
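The facility location score can be computed directly from its definition. This sketch assumes Euclidean distance and represents the medoid set S as indices into the batch embeddings, i.e., the case where candidate medoids are restricted to the batch; `facility_location_score` is a hypothetical helper name.

```python
import math


def facility_location_score(embeddings, medoid_indices):
    """Facility location score F(X, S; Theta).

    embeddings: the vectors f(X_i; Theta) for every training item in the batch.
    medoid_indices: the set S, given as indices into `embeddings` (this sketch
    assumes medoids are drawn from the batch itself).
    Returns minus the sum, over items, of the Euclidean distance to the nearest medoid.
    """
    return -sum(
        min(math.dist(e, embeddings[j]) for j in medoid_indices)
        for e in embeddings
    )


batch = [(0.0, 0.0), (0.0, 1.0), (3.0, 0.0)]
print(facility_location_score(batch, [0, 2]))  # -1.0
print(facility_location_score(batch, [0]))     # -4.0
```

The score is never positive and reaches its maximum of zero only when every embedding coincides with some medoid.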
The oracle clustering score measures the quality of the clustering of the training examples in the batch given the ground truth clustering assignment and the network parameters, i.e., the quality of the clustering defined by the embeddings for the training inputs in the batch generated in accordance with the network parameters. In particular, the oracle clustering function $\tilde{F}$ that generates the oracle clustering score is expressed as:
$$\tilde{F}(X, y^*; \Theta) = \sum_{k}^{|\mathcal{Y}|} \max_{j \in \{i : y^*[i] = k\}} F\left(X_{\{i : y^*[i] = k\}}, \{j\}; \Theta\right),$$
where $|\mathcal{Y}|$ is the total number of clusters in the ground truth clustering assignment $y^*$, the sum is a sum over all of the clusters in the ground truth clustering assignment, $\{i : y^*[i] = k\}$ is the subset of the training examples in the batch that are clustered into cluster k by the ground truth clustering assignment, and, for the cluster k, $\max_{j \in \{i : y^*[i] = k\}} F(X_{\{i : y^*[i] = k\}}, \{j\}; \Theta)$ is the largest clustering score for only the training items in the cluster k from among the clustering scores generated when the medoid for the cluster k is any of the embeddings for any of the training items in the cluster k.
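The oracle score can be sketched the same way: group the batch by ground-truth cluster, try every member of each cluster as that cluster's only medoid, and sum the best scores. Euclidean distance and the helper names `facility_location_score` and `oracle_score` are assumptions made for this illustration.

```python
import math
from collections import defaultdict


def facility_location_score(embeddings, medoid_indices):
    """F(X, S; Theta): minus the total Euclidean distance to the nearest medoid."""
    return -sum(min(math.dist(e, embeddings[j]) for j in medoid_indices) for e in embeddings)


def oracle_score(embeddings, y_star):
    """Oracle score: for each ground-truth cluster k, the best single-medoid
    facility location score over only the items in k, summed over clusters."""
    clusters = defaultdict(list)
    for i, k in enumerate(y_star):
        clusters[k].append(embeddings[i])
    total = 0.0
    for members in clusters.values():
        # Try every member of the cluster as its medoid and keep the best score.
        total += max(facility_location_score(members, [j]) for j in range(len(members)))
    return total


batch = [(0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (5.0, 5.0)]
print(oracle_score(batch, [0, 0, 0, 1]))  # -2.0
```

In the example the middle point (0, 1) is the best medoid for the first cluster, and the singleton cluster contributes zero.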
The structured margin between a given possible clustering assignment and the ground truth clustering assignment measures a quality of the possible clustering assignment relative to the ground truth assignment. In particular, in some implementations, the structured margin is based on a normalized mutual information measure between the possible clustering assignment and the ground truth assignment. In these implementations, the structured margin Δ between a cluster assignment y and the ground truth assignment y* is expressed as:
$$\Delta(y, y^*) = 1 - \mathrm{NMI}(y, y^*),$$
where NMI(y, y*) is the normalized mutual information between the two assignments and satisfies:
$$\mathrm{NMI}(y, y^*) = \frac{\mathrm{MI}(y, y^*)}{\sqrt{H(y)\, H(y^*)}},$$
where MI is the mutual information between the two assignments and H is the entropy of an assignment.
Generally, both the mutual information and the entropy are based on marginal probabilities of clusters in the two assignments, and the mutual information is additionally based on joint probabilities between one cluster in one assignment and another cluster in the other assignment. The marginal probability used for computing the entropy and the mutual information can be estimated as, for a given cluster and a given assignment, the fraction of the training items that are assigned to the given cluster by the given assignment. The joint probability used for computing the mutual information between a cluster i in a first assignment and a cluster j in a second assignment can be estimated as the fraction of the training items that are assigned to the cluster i by the first assignment and to the cluster j by the second assignment.
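The structured margin can be estimated from label counts as described, with marginals taken as cluster fractions and joint probabilities as co-assignment fractions. This sketch assumes the common geometric-mean normalization NMI = MI / sqrt(H(y) H(y*)) and picks its own convention for the degenerate case where an assignment has a single cluster (zero entropy); neither choice should be read as the specification's. The function names are hypothetical.

```python
import math
from collections import Counter


def entropy(assignment):
    """H(y), using the fraction of items in each cluster as its marginal probability."""
    n = len(assignment)
    return -sum(c / n * math.log(c / n) for c in Counter(assignment).values())


def mutual_information(a, b):
    """MI(a, b), using co-assignment fractions as joint probabilities."""
    n = len(a)
    count_a, count_b = Counter(a), Counter(b)
    return sum(
        c / n * math.log(c * n / (count_a[x] * count_b[y]))
        for (x, y), c in Counter(zip(a, b)).items()
    )


def structured_margin(y, y_star):
    """Delta(y, y*) = 1 - NMI(y, y*), with NMI normalized by sqrt(H(y) H(y*))."""
    h_y, h_star = entropy(y), entropy(y_star)
    if h_y == 0.0 or h_star == 0.0:
        # Degenerate single-cluster case; this convention is an assumption.
        nmi = 1.0 if h_y == h_star else 0.0
    else:
        nmi = mutual_information(y, y_star) / math.sqrt(h_y * h_star)
    # Clamp tiny negative values caused by floating-point rounding.
    return max(0.0, 1.0 - nmi)


print(round(structured_margin([1, 1, 0, 0], [0, 0, 1, 1]), 6))  # 0.0 (same partition, relabeled)
print(round(structured_margin([0, 1, 0, 1], [0, 0, 1, 1]), 6))  # 1.0 (independent partitions)
```

Because NMI depends only on how items are grouped, relabeling clusters leaves the margin unchanged, which is why the first example scores zero.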
The clustering loss function for a batch of training inputs X and a ground truth clustering assignment $y^*$ for the batch can then satisfy:
$$\ell(X, y^*) = \left[\max_{S : |S| = |\mathcal{Y}|} \left\{ F(X, S; \Theta) + \gamma\, \Delta(g(S), y^*) \right\} - \tilde{F}(X, y^*; \Theta)\right]_+,$$
where the maximum is over the possible clustering assignments, i.e., possible sets of medoids that have the same number of medoids as there are clusters in the ground truth assignment, γ is a positive constant value, g(S) is a function that assigns each training item to the cluster that is represented by the closest medoid in the assignment S to the embedding for the training item, and $[a]_+$ equals zero if a is less than or equal to zero and a if a is greater than zero. The term $F(X, S; \Theta) + \gamma\, \Delta(g(S), y^*)$ for a given clustering assignment will be referred to in this specification as the augmented clustering score for the clustering assignment. Training the neural network on this objective is described in more detail below with reference to FIGS. 2 and 3.
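Putting the pieces together, the clustering loss for a tiny batch can be evaluated by brute force over every medoid set of size |Y| drawn from the batch. This is a sketch under stated assumptions (Euclidean distance, medoids restricted to batch items, geometric-mean NMI, γ = 1 by default), and all names are hypothetical. Exhaustive search is exponential in the batch size, so it only makes the definition concrete; the greedy inference and refinement steps described with reference to FIG. 3 would replace it in practice.

```python
import math
from collections import Counter, defaultdict
from itertools import combinations


def facility_location_score(emb, S):
    """F(X, S; Theta) with medoids S given as indices into the batch."""
    return -sum(min(math.dist(e, emb[j]) for j in S) for e in emb)


def oracle_score(emb, y_star):
    """Sum over ground-truth clusters of the best single in-cluster medoid score."""
    clusters = defaultdict(list)
    for i, k in enumerate(y_star):
        clusters[k].append(emb[i])
    return sum(max(facility_location_score(m, [j]) for j in range(len(m)))
               for m in clusters.values())


def nmi(a, b):
    """NMI normalized by sqrt(H(a) H(b)); single-cluster convention is assumed."""
    n = len(a)
    ca, cb = Counter(a), Counter(b)
    ha = -sum(c / n * math.log(c / n) for c in ca.values())
    hb = -sum(c / n * math.log(c / n) for c in cb.values())
    if ha <= 0.0 or hb <= 0.0:
        return 1.0 if ha == hb else 0.0
    mi = sum(c / n * math.log(c * n / (ca[x] * cb[y]))
             for (x, y), c in Counter(zip(a, b)).items())
    return mi / math.sqrt(ha * hb)


def assign(emb, S):
    """g(S): label each item with the position in S of its nearest medoid."""
    return [min(range(len(S)), key=lambda m: math.dist(e, emb[S[m]])) for e in emb]


def clustering_loss(emb, y_star, gamma=1.0):
    """Hinge of (best augmented score over medoid sets) minus the oracle score."""
    k = len(set(y_star))
    best = max(
        facility_location_score(emb, S) + gamma * (1.0 - nmi(assign(emb, S), y_star))
        for S in combinations(range(len(emb)), k)
    )
    return max(0.0, best - oracle_score(emb, y_star))


emb = [(0.0, 0.0), (0.0, 0.1), (10.0, 10.0), (10.0, 10.1)]
print(round(clustering_loss(emb, [0, 0, 1, 1]), 6))  # 0.0: embeddings agree with y*
print(round(clustering_loss(emb, [0, 1, 0, 1]), 2))  # 29.08: embeddings contradict y*
```

When the embeddings already separate the ground-truth clusters, the best augmented assignment coincides with the oracle clustering and the hinge is inactive; when they contradict y*, the oracle term collapses and the loss becomes large.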
Once the neural network has been trained, the system 100 provides data specifying the trained neural network for use in processing new network inputs. That is, the system 100 can output, e.g., by outputting to a user device or by storing in a memory accessible to the system 100, the trained values of the network parameters for later use in processing inputs using the trained neural network. Alternatively or in addition to outputting the trained neural network data, the system 100 can instantiate an instance of the neural network having the trained values of the network parameters, receive inputs to be processed, e.g., through an application programming interface (API) offered by the system, use the trained neural network to process the received inputs to generate embeddings, and then provide the generated embeddings in response to the received inputs.
FIG. 2 is a flow diagram of an example process 200 for training a neural network on a batch of training data. For convenience, the process 200 will be described as being performed by a system of one or more computers located in one or more locations. For example, a neural network training system, e.g., the neural network training system 100 of FIG. 1, appropriately programmed, can perform the process 200.
The system can perform the process 200 multiple times for multiple different batches of training items to determine trained values of the network parameters from initial values of the network parameters.
The system obtains a batch of training items and a ground truth assignment of the training items in the batch into a plurality of clusters (step 202). The ground truth assignment assigns each training item in the batch to a respective cluster from the set of clusters.
The system processes each training item in the batch using the neural network and in accordance with current values of the network parameters to generate a respective embedding for each of the training items (step 204).
The system determines an oracle clustering score for the ground truth assignment based on the embeddings for the training items in the batch (step 206). As described above, the oracle clustering score measures the quality of the clustering given the ground truth clustering assignment and the network parameters, i.e., of the clustering defined by the embeddings generated in accordance with the current values of the network parameters.
The system adjusts the current values of the network parameters by performing an iteration of a neural network training procedure to optimize, i.e., minimize, the clustering objective using the oracle clustering score (step 208). Generally, the training procedure determines an update to the current values of the parameters from a gradient of the clustering objective with respect to the parameters and then applies, e.g., adds, the update to the current values to determine updated values of the parameters. For example, the training procedure can be stochastic gradient descent, and the system can multiply the gradient by the negative of a learning rate to determine the update and then add the update to the current values of the network parameters. Determining the gradient of the clustering objective is described in more detail below with reference to FIG. 3.
FIG. 3 is a flow diagram of an example process 300 for determining an update to current values of the network parameters. For convenience, the process 300 will be described as being performed by a system of one or more computers located in one or more locations. For example, a neural network training system, e.g., the neural network training system 100 of FIG. 1, appropriately programmed, can perform the process 300.
The system can perform the process 300 during the training of the neural network on a batch of training inputs to determine an update to the current values of the network parameters for the batch. The system can then apply, i.e., add, the updates determined for the inputs in the batch to generate updated values of the network parameters.
The system determines the possible clustering assignment other than the ground truth assignment that has the highest augmented clustering score (step 302).
As described above, the augmented clustering score is the clustering score for the possible clustering assignment plus the structured margin between the possible clustering assignment and the ground truth assignment, scaled by the constant γ.
In some implementations, to determine the highest scoring clustering assignment, the system determines an initial best possible clustering assignment using an iterative loss augmented inference technique. In particular, at each iteration of the inference technique, the system adds a new medoid to the clustering assignment. That is, the system starts with a clustering assignment that has zero medoids, i.e., does not assign any embeddings to any clusters, and continues adding medoids until the number of medoids in the clustering assignment equals the number of medoids, i.e., number of clusters, in the ground truth assignment. At each step of the inference technique, the system adds to the current clustering assignment the medoid that most increases the augmented clustering score for the clustering assignment.
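The greedy loss-augmented inference step can be sketched as follows. It considers only the batch embeddings as candidate medoids (one of the options described for this step), and it assumes Euclidean distance and geometric-mean NMI; all function names are hypothetical.

```python
import math
from collections import Counter


def facility_location_score(emb, S):
    """F(X, S; Theta) with medoids S given as indices into the batch."""
    return -sum(min(math.dist(e, emb[j]) for j in S) for e in emb)


def nmi(a, b):
    """NMI normalized by sqrt(H(a) H(b)); single-cluster convention is assumed."""
    n = len(a)
    ca, cb = Counter(a), Counter(b)
    ha = -sum(c / n * math.log(c / n) for c in ca.values())
    hb = -sum(c / n * math.log(c / n) for c in cb.values())
    if ha <= 0.0 or hb <= 0.0:
        return 1.0 if ha == hb else 0.0
    mi = sum(c / n * math.log(c * n / (ca[x] * cb[y]))
             for (x, y), c in Counter(zip(a, b)).items())
    return mi / math.sqrt(ha * hb)


def augmented_score(emb, S, y_star, gamma):
    """F(X, S; Theta) + gamma * (1 - NMI(g(S), y*))."""
    labels = [min(range(len(S)), key=lambda m: math.dist(e, emb[S[m]])) for e in emb]
    return facility_location_score(emb, S) + gamma * (1.0 - nmi(labels, y_star))


def greedy_inference(emb, y_star, gamma=1.0):
    """Start with no medoids and repeatedly add the candidate that most increases
    the augmented score, until there is one medoid per ground-truth cluster."""
    S = []
    for _ in range(len(set(y_star))):
        candidates = [j for j in range(len(emb)) if j not in S]
        best = max(candidates, key=lambda j: augmented_score(emb, S + [j], y_star, gamma))
        S.append(best)
    return S


emb = [(0.0, 0.0), (0.0, 0.1), (10.0, 10.0), (10.0, 10.1)]
S = greedy_inference(emb, [0, 0, 1, 1])  # one medoid from each well-separated group
```

Each greedy step evaluates every remaining candidate once, so the sketch runs in roughly |Y| times the batch size score evaluations rather than enumerating all medoid sets.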
The system then modifies the initial best possible clustering assignment using a loss augmented refinement technique to determine the highest scoring possible clustering assignment. In particular, the system performs multiple iterations of the refinement technique to determine the highest scoring possible clustering assignment from the initial best possible clustering assignment.
At each iteration and for each cluster in a current best possible clustering assignment, the system determines whether modifying the current best possible clustering assignment by performing a pairwise exchange of a current medoid for the cluster with an alternative point in the same cluster according to the current best possible clustering assignment would increase the augmented clustering score for the clustering assignment, and, if so, swaps the current medoid for the alternative point to update the clustering assignment.
The number of iterations of the loss augmented refinement technique to be performed may be fixed, e.g., to three, five, or seven iterations, or the system may continue performing the technique until none of the pairwise swaps would improve the augmented clustering score.
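A sketch of the pairwise-swap refinement, under the same assumptions as a greedy inference step (batch embeddings as candidates, Euclidean distance, geometric-mean NMI). For each cluster it evaluates every swap of the medoid with another point currently assigned to that cluster, keeps the best improving swap, and stops after `max_iters` passes or once a full pass finds no improvement. Taking the best swap per cluster, rather than the first improving one, is a choice made here for determinism; the helper names are hypothetical.

```python
import math
from collections import Counter


def facility_location_score(emb, S):
    return -sum(min(math.dist(e, emb[j]) for j in S) for e in emb)


def nmi(a, b):
    """NMI normalized by sqrt(H(a) H(b)); single-cluster convention is assumed."""
    n = len(a)
    ca, cb = Counter(a), Counter(b)
    ha = -sum(c / n * math.log(c / n) for c in ca.values())
    hb = -sum(c / n * math.log(c / n) for c in cb.values())
    if ha <= 0.0 or hb <= 0.0:
        return 1.0 if ha == hb else 0.0
    mi = sum(c / n * math.log(c * n / (ca[x] * cb[y]))
             for (x, y), c in Counter(zip(a, b)).items())
    return mi / math.sqrt(ha * hb)


def nearest(emb, S):
    """g(S): position in S of each item's nearest medoid."""
    return [min(range(len(S)), key=lambda m: math.dist(e, emb[S[m]])) for e in emb]


def augmented_score(emb, S, y_star, gamma):
    return facility_location_score(emb, S) + gamma * (1.0 - nmi(nearest(emb, S), y_star))


def refine(emb, y_star, S, gamma=1.0, max_iters=5):
    """Loss-augmented refinement by swapping each medoid with points in its cluster."""
    S = list(S)
    best = augmented_score(emb, S, y_star, gamma)
    for _ in range(max_iters):
        improved = False
        for m in range(len(S)):
            labels = nearest(emb, S)
            best_alt = None
            for alt in (i for i, lab in enumerate(labels) if lab == m and i != S[m]):
                score = augmented_score(emb, S[:m] + [alt] + S[m + 1:], y_star, gamma)
                if score > best:
                    best, best_alt = score, alt
            if best_alt is not None:
                S[m] = best_alt
                improved = True
        if not improved:
            break
    return S


emb = [(0.0, 0.0), (0.0, 1.0), (0.0, 2.0), (10.0, 10.0)]
print(refine(emb, [0, 0, 0, 1], [0, 3]))  # [1, 3]: the central point becomes the medoid
```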
In some implementations, for both the iterative loss augmented inference technique and the loss augmented refinement technique, the system considers as candidate medoids only the embeddings for the training items in the batch. In some other implementations, the system also considers points in the embedding space that are not embeddings of training items in the batch, e.g., the entire space of possible points in the space or a predetermined discrete subset of possible points.
The system determines a gradient of the clustering objective with respect to the network parameters using the highest scoring clustering assignment (step 304).
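In particular, the system can determine the gradient as the difference between the gradient of the augmented clustering scoring function term of the objective function with respect to the network parameters and the gradient of the oracle scoring function term of the objective function with respect to the network parameters. That is, the total gradient is a first gradient term, i.e., the gradient of the augmented clustering scoring function, minus a second gradient term, i.e., the gradient of the oracle scoring function. Because the structured margin is piecewise constant in the network parameters for a fixed assignment S, the first gradient term reduces to the gradient of the clustering score and satisfies:
$$\nabla_\Theta F(X, S; \Theta) = -\sum_{i \in |X|} \left[ \frac{f(X_i; \Theta) - f(X_{j^*(i)}; \Theta)}{\lVert f(X_i; \Theta) - f(X_{j^*(i)}; \Theta) \rVert} \right]^\top \nabla_\Theta \left( f(X_i; \Theta) - f(X_{j^*(i)}; \Theta) \right),$$
where the sum is a sum over all of the training inputs in the batch, $f(X_i; \Theta)$ is the embedding of the i-th training input $X_i$, $f(X_{j^*(i)}; \Theta)$ is the closest medoid to the embedding $f(X_i; \Theta)$ from the medoids in the highest scoring assignment S, and $\nabla_\Theta \left( f(X_i; \Theta) - f(X_{j^*(i)}; \Theta) \right)$ is the gradient with respect to the network parameters of $f(X_i; \Theta) - f(X_{j^*(i)}; \Theta)$.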
The second gradient term satisfies:
$$\nabla_\Theta \tilde{F}(X, y^*; \Theta) = \sum_{k}^{|\mathcal{Y}|} \nabla_\Theta F\left(X_{\{i : y^*[i] = k\}}, \{j^*(k)\}; \Theta\right),$$
where the sum is over the clusters in the ground truth assignment, and $F(X_{\{i : y^*[i] = k\}}, \{j^*(k)\}; \Theta)$ is the clustering score for the training inputs that are assigned to cluster k by the ground truth clustering assignment relative to the medoid for the cluster k in the highest scoring assignment.
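To make the two gradient terms concrete, this sketch replaces the neural network with a hypothetical linear embedding f(x; W) = xW, so that backpropagation reduces to the explicit chain rule in `score_grad`. It assumes Euclidean distance and treats the gradient of a zero-length distance term as zero. For the second term it differentiates the oracle score directly, so each cluster's medoid is taken to be the in-cluster point that attains the maximum in the oracle score; that reading is an assumption of this sketch. In a real network the same quantities would come from automatic differentiation.

```python
import math


def embed(x, W):
    """Hypothetical stand-in for the network: a linear map f(x; W) = x W."""
    return [sum(x[d] * W[d][c] for d in range(len(x))) for c in range(len(W[0]))]


def score(X, S, W):
    """F(X, S; W) with medoids S given as indices into X."""
    E = [embed(x, W) for x in X]
    return -sum(min(math.dist(e, E[j]) for j in S) for e in E)


def score_grad(X, S, W):
    """Closed-form gradient of F(X, S; W) with respect to W (first gradient term)."""
    E = [embed(x, W) for x in X]
    grad = [[0.0] * len(W[0]) for _ in W]
    for i, e in enumerate(E):
        j = min(S, key=lambda s: math.dist(e, E[s]))  # j*(i): nearest medoid
        norm = math.dist(e, E[j])
        if norm == 0.0:
            continue  # zero distance (e.g. the item is a medoid): treat gradient as 0
        for d in range(len(W)):
            for c in range(len(W[0])):
                # d/dW[d][c] of -||f(x_i) - f(x_j)|| for the linear map above.
                grad[d][c] -= (X[i][d] - X[j][d]) * (e[c] - E[j][c]) / norm
    return grad


def oracle_grad(X, y_star, W):
    """Second gradient term: sum over ground-truth clusters of the gradient of
    each cluster's score about its best in-cluster medoid."""
    grad = [[0.0] * len(W[0]) for _ in W]
    for k in set(y_star):
        Xk = [x for x, lab in zip(X, y_star) if lab == k]
        jk = max(range(len(Xk)), key=lambda j: score(Xk, [j], W))
        gk = score_grad(Xk, [jk], W)
        for d in range(len(W)):
            for c in range(len(W[0])):
                grad[d][c] += gk[d][c]
    return grad


X = [(1.0, 0.5), (0.2, -1.0), (-0.7, 0.3), (0.4, 0.9)]
W = [[0.3, -0.2, 0.5], [0.1, 0.4, -0.6]]
S_hat = [0, 2]          # e.g. the output of the inference and refinement steps
y_star = [0, 0, 0, 1]
g_first, g_second = score_grad(X, S_hat, W), oracle_grad(X, y_star, W)
total = [[a - b for a, b in zip(r1, r2)] for r1, r2 in zip(g_first, g_second)]
```

Following the description, `total` would only be used when the loss for the batch is positive; otherwise the gradient is set to zero and the parameters are left unchanged.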
The system can determine the gradients in the first gradient term and the second gradient term with respect to all of the parameters of the neural network using conventional neural network training techniques, i.e., by backpropagating gradients through the neural network.
In some implementations, the system only computes the gradient as above if the loss for the batch, i.e., the value of the clustering loss function l(X, y*) for the batch, is greater than zero. If the loss is less than or equal to zero, the system sets the gradients to zero and does not update the current values of the network parameters.
The system then determines an update to the current values of the network parameters from the gradient of the clustering objective (step 306). For example, the system can determine the update by applying a learning rate to the gradient.
This specification uses the term "configured" in connection with systems and computer program components. For a system of one or more computers to be configured to perform particular operations or actions means that the system has installed on it software, firmware, hardware, or a combination of them that in operation cause the system to perform the operations or actions. For one or more computer programs to be configured to perform particular operations or actions means that the one or more programs include instructions that, when executed by data processing apparatus, cause the apparatus to perform the operations or actions.
Embodiments of the subject matter and the functional operations described in this specification can be implemented in digital electronic circuitry, in tangibly-embodied computer software or firmware, in computer hardware, including the structures disclosed in this specification and their structural equivalents, or in combinations of one or more of them. Embodiments of the subject matter described in this specification can be implemented as one or more computer programs, i.e., one or more modules of computer program instructions encoded on a tangible non-transitory storage medium for execution by, or to control the operation of, data processing apparatus. The computer storage medium can be a machine-readable storage device, a machine-readable storage substrate, a random or serial access memory device, or a combination of one or more of them. Alternatively or in addition, the program instructions can be encoded on an artificially generated propagated signal, e.g., a machine-generated electrical, optical, or electromagnetic signal, that is generated to encode information for transmission to suitable receiver apparatus for execution by a data processing apparatus.
The term "data processing apparatus" refers to data processing hardware and encompasses all kinds of apparatus, devices, and machines for processing data, including by way of example a programmable processor, a computer, or multiple processors or computers. The apparatus can also be, or further include, special purpose logic circuitry, e.g., an FPGA (field programmable gate array) or an ASIC (application specific integrated circuit). The apparatus can optionally include, in addition to hardware, code that creates an execution environment for computer programs, e.g., code that constitutes processor firmware, a protocol stack, a database management system, an operating system, or a combination of one or more of them.
A computer program, which may also be referred to or described as a program, software, a software application, an app, a module, a software module, a script, or code, can be written in any form of programming language, including compiled or interpreted languages, or declarative or procedural languages; and it can be deployed in any form, including as a standalone program or as a module, component, subroutine, or other unit suitable for use in a computing environment. A program may, but need not, correspond to a file in a file system. A program can be stored in a portion of a file that holds other programs or data, e.g., one or more scripts stored in a markup language document, in a single file dedicated to the program in question, or in multiple coordinated files, e.g., files that store one or more modules, subprograms, or portions of code. A computer program can be deployed to be executed on one computer or on multiple computers that are located at one site or distributed across multiple sites and interconnected by a data communication network.
In this specification, the term "database" is used broadly to refer to any collection of data: the data does not need to be structured in any particular way, or structured at all, and it can be stored on storage devices in one or more locations. Thus, for example, the index database can include multiple collections of data, each of which may be organized and accessed differently.
Similarly, in this specification the term "engine" is used broadly to refer to a software-based system, subsystem, or process that is programmed to perform one or more specific functions. Generally, an engine will be implemented as one or more software modules or components, installed on one or more computers in one or more locations. In some cases, one or more computers will be dedicated to a particular engine; in other cases, multiple engines can be installed and running on the same computer or computers.
The processes and logic flows described in this specification can be performed by one or more programmable computers executing one or more computer programs to perform functions by operating on input data and generating output. The processes and logic flows can also be performed by special purpose logic circuitry, e.g., an FPGA or an ASIC, or by a combination of special purpose logic circuitry and one or more programmed computers.
Computers suitable for the execution of a computer program can be based on general or special purpose microprocessors or both, or any other kind of central processing unit. Generally, a central processing unit will receive instructions and data from a read only memory or a random access memory or both. The essential elements of a computer are a central processing unit for performing or executing instructions and one or more memory devices for storing instructions and data. The central processing unit and the memory can be supplemented by, or incorporated in, special purpose logic circuitry. Generally, a computer will also include, or be operatively coupled to receive data from or transfer data to, or both, one or more mass storage devices for storing data, e.g., magnetic, magneto-optical disks, or optical disks. However, a computer need not have such devices. Moreover, a computer can be embedded in another device, e.g., a mobile telephone, a personal digital assistant (PDA), a mobile audio or video player, a game console, a Global Positioning System (GPS) receiver, or a portable storage device, e.g., a universal serial bus (USB) flash drive, to name just a few.
Computer-readable media suitable for storing computer program instructions and data include all forms of non-volatile memory, media and memory devices, including by way of example semiconductor memory devices, e.g., EPROM, EEPROM, and flash memory devices; magnetic disks, e.g., internal hard disks or removable disks; magneto-optical disks; and CD-ROM and DVD-ROM disks.
To provide for interaction with a user, embodiments of the subject matter described in this specification can be implemented on a computer having a display device, e.g., a CRT (cathode ray tube) or LCD (liquid crystal display) monitor, for displaying information to the user and a keyboard and a pointing device, e.g., a mouse or a trackball, by which the user can provide input to the computer. Other kinds of devices can be used to provide for interaction with a user as well; for example, feedback provided to the user can be any form of sensory feedback, e.g., visual feedback, auditory feedback, or tactile feedback; and input from the user can be received in any form, including acoustic, speech, or tactile input. In addition, a computer can interact with a user by sending documents to and receiving documents from a device that is used by the user; for example, by sending web pages to a web browser on a user's device in response to requests received from the web browser. Also, a computer can interact with a user by sending text messages or other forms of message to a personal device, e.g., a smartphone that is running a messaging application, and receiving responsive messages from the user in return.
Data processing apparatus for implementing machine learning models can also include, for example, special-purpose hardware accelerator units for processing common and compute-intensive parts of machine learning training or production, i.e., inference, workloads.
Machine learning models can be implemented and deployed using a machine learning framework, e.g., a TensorFlow framework, a Microsoft Cognitive Toolkit framework, an Apache Singa framework, or an Apache MXNet framework.
Embodiments of the subject matter described in this specification can be implemented in a computing system that includes a back end component, e.g., as a data server, or that includes a middleware component, e.g., an application server, or that includes a front end component, e.g., a client computer having a graphical user interface, a web browser, or an app through which a user can interact with an implementation of the subject matter described in this specification, or any combination of one or more such back end, middleware, or front end components. The components of the system can be interconnected by any form or medium of digital data communication, e.g., a communication network. Examples of communication networks include a local area network (LAN) and a wide area network (WAN), e.g., the Internet.
The computing system can include clients and servers. A client and server are generally remote from each other and typically interact through a communication network. The relationship of client and server arises by virtue of computer programs running on the respective computers and having a client-server relationship to each other. In some embodiments, a server transmits data, e.g., an HTML page, to a user device, e.g., for purposes of displaying data to and receiving user input from a user interacting with the device, which acts as a client. Data generated at the user device, e.g., a result of the user interaction, can be received at the server from the device.
While this specification contains many specific implementation details, these should not be construed as limitations on the scope of any invention or on the scope of what may be claimed, but rather as descriptions of features that may be specific to particular embodiments of particular inventions. Certain features that are described in this specification in the context of separate embodiments can also be implemented in combination in a single embodiment. Conversely, various features that are described in the context of a single embodiment can also be implemented in multiple embodiments separately or in any suitable subcombination. Moreover, although features may be described above as acting in certain combinations and even initially be claimed as such, one or more features from a claimed combination can in some cases be excised from the combination, and the claimed combination may be directed to a subcombination or variation of a subcombination.
Similarly, while operations are depicted in the drawings and recited in the claims in a particular order, this should not be understood as requiring that such operations be performed in the particular order shown or in sequential order, or that all illustrated operations be performed, to achieve desirable results. In certain circumstances, multitasking and parallel processing may be advantageous. Moreover, the separation of various system modules and components in the embodiments described above should not be understood as requiring such separation in all embodiments, and it should be understood that the described program components and systems can generally be integrated together in a single software product or packaged into multiple software products.
Particular embodiments of the subject matter have been described. Other embodiments are within the scope of the following claims. For example, the actions recited in the claims can be performed in a different order and still achieve desirable results. As one example, the processes depicted in the accompanying figures do not necessarily require the particular order shown, or sequential order, to achieve desirable results. In some cases, multitasking and parallel processing may be advantageous.

Claims

What is claimed is:
1. A method of training a neural network that has a plurality of network parameters and that is configured to receive an input data item and to process the input data item to generate an embedding of the input data item in accordance with the network parameters, the method comprising:
obtaining a batch of training items and a ground truth assignment of the training items in the batch into a plurality of clusters;
processing the training items in the batch using the neural network and in accordance with current values of the network parameters to generate respective embeddings for each of the training items;
determining an oracle clustering score for the ground truth assignment based on the respective embeddings; and
adjusting the current values of the network parameters by performing an iteration of a neural network training procedure to optimize an objective function that penalizes the neural network for generating embeddings that do not result in, for each possible clustering assignment other than the ground truth assignment, the oracle clustering score being higher than a clustering score for the possible clustering assignment by at least a structured margin between the possible clustering assignment and the ground truth assignment.
2. The method of claim 1, wherein the structured margin measures a quality of the possible clustering assignment relative to the ground truth assignment.
3. The method of claim 2, wherein the structured margin is based on a normalized mutual information measure between the possible clustering assignment and the ground truth assignment.
4. The method of any one of claims 1-3, wherein performing the iteration of the neural network training procedure comprises:
determining the possible clustering assignment other than the ground truth assignment that has a highest augmented clustering score, wherein the augmented clustering score is the clustering score for the possible clustering assignment plus the structured margin between the possible clustering assignment and the ground truth assignment.
5. The method of claim 4, wherein determining the possible clustering assignment other than the ground truth assignment that has the highest augmented clustering score comprises:
determining an initial best possible clustering assignment using an iterative loss augmented inference technique.
6. The method of claim 5, wherein determining the initial best possible clustering assignment using the iterative loss augmented inference technique comprises:
at each step of the inference technique, adding a medoid to the clustering assignment that most increases the augmented clustering score.
7. The method of any one of claims 5 or 6, wherein determining the possible clustering assignment other than the ground truth assignment that has the highest augmented clustering score comprises:
modifying the initial best possible clustering assignment using a loss augmented refinement technique to determine the possible clustering assignment other than the ground truth assignment that has the highest augmented clustering score.
8. The method of claim 7, wherein modifying the initial best possible clustering assignment using the loss augmented refinement technique comprises:
performing a pairwise exchange of a current medoid in a current best possible clustering assignment with an alternative point in the same cluster according to the current best possible clustering assignment; and
swapping the current medoid for the alternative point if the pairwise exchange increases the augmented clustering score.
9. The method of any one of claims 4-8, wherein performing the iteration of the neural network training procedure further comprises:
determining a gradient of the objective function with respect to the network parameters using the possible clustering assignment other than the ground truth assignment that has a highest augmented clustering score; and
determining an update to the current values of the network parameters using the gradient.
10. The method of any one of claims 1-9, further comprising:
determining the oracle clustering score using a facility location scoring function.
11. The method of any one of claims 1-10, wherein the neural network training procedure is stochastic gradient descent.
12. The method of any one of claims 1-11, further comprising:
providing the trained neural network for use in generating embeddings for new input data items.
13. A system comprising one or more computers and one or more storage devices storing instructions that when executed by the one or more computers cause the one or more computers to perform the operations of the respective method of any one of claims 1-12.
14. One or more computer storage media storing instructions that when executed by one or more computers cause the one or more computers to perform the operations of the respective method of any one of claims 1-12.
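The objective recited in claims 1 to 10 can be illustrated with a small, framework-free sketch. It is not the patented implementation, and several details are assumptions rather than claim language: Euclidean distance between embeddings; a facility-location score F(S) = -sum_i min_{j in S} ||e_i - e_j|| for claim 10; a structured margin of gamma * (1 - NMI) for claims 2 and 3 (the claims only say the margin is based on NMI); an oracle score that picks the best single in-cluster medoid for each ground-truth cluster; and a number of medoids equal to the number of ground-truth clusters. All function names (`facility_location_score`, `greedy_inference`, `refine`, `clustering_loss`, etc.) are hypothetical. Only the loss value is computed; the gradient update of claims 9 and 11 would in practice come from an automatic-differentiation framework applied to the same expression.

```python
import math
from collections import Counter
from itertools import product


def dist(a, b):
    # Euclidean distance between two embedding vectors (assumed metric).
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def facility_location_score(emb, medoids):
    # F(S) = -sum_i min_{j in S} ||e_i - e_j||: higher when every item is near a medoid.
    return -sum(min(dist(e, emb[j]) for j in medoids) for e in emb)


def assign(emb, medoids):
    # Clustering assignment induced by the medoids: each item joins its nearest medoid.
    return [min(range(len(medoids)), key=lambda k: dist(e, emb[medoids[k]])) for e in emb]


def nmi(a, b):
    # Normalized mutual information, normalized by sqrt(H(a) * H(b)) (one common choice).
    n = len(a)
    ca, cb, cab = Counter(a), Counter(b), Counter(zip(a, b))
    mi = sum(c / n * math.log(c * n / (ca[x] * cb[y])) for (x, y), c in cab.items())
    ha = -sum(c / n * math.log(c / n) for c in ca.values())
    hb = -sum(c / n * math.log(c / n) for c in cb.values())
    if ha == 0.0 or hb == 0.0:
        return 1.0 if ha == hb else 0.0
    return mi / math.sqrt(ha * hb)


def augmented_score(emb, medoids, truth, gamma):
    # Clustering score plus the assumed structured margin gamma * (1 - NMI).
    margin = gamma * (1.0 - nmi(assign(emb, medoids), truth))
    return facility_location_score(emb, medoids) + margin


def oracle_score(emb, truth):
    # Score of the ground truth: best in-cluster medoid for each true cluster, summed.
    total = 0.0
    for label in set(truth):
        sub = [e for e, t in zip(emb, truth) if t == label]
        total += max(facility_location_score(sub, [j]) for j in range(len(sub)))
    return total


def greedy_inference(emb, truth, gamma):
    # Loss-augmented greedy step: repeatedly add the medoid that most raises the score.
    medoids = []
    for _ in range(len(set(truth))):  # assumption: K = number of true clusters
        candidates = [j for j in range(len(emb)) if j not in medoids]
        best = max(candidates, key=lambda j: augmented_score(emb, medoids + [j], truth, gamma))
        medoids.append(best)
    return medoids


def refine(emb, medoids, truth, gamma):
    # Pairwise exchange: swap a medoid for another member of its cluster if that helps.
    medoids = list(medoids)
    current = augmented_score(emb, medoids, truth, gamma)
    improved = True
    while improved:
        improved = False
        labels = assign(emb, medoids)
        for k, i in product(range(len(medoids)), range(len(emb))):
            if labels[i] != k or i in medoids:
                continue
            trial = medoids[:k] + [i] + medoids[k + 1:]
            score = augmented_score(emb, trial, truth, gamma)
            if score > current + 1e-12:
                medoids, current, improved = trial, score, True
                break  # recompute the assignment after each accepted swap
    return medoids


def clustering_loss(emb, truth, gamma=1.0):
    # Hinge on how far the highest augmented score found by the search exceeds the oracle.
    medoids = refine(emb, greedy_inference(emb, truth, gamma), truth, gamma)
    return max(0.0, augmented_score(emb, medoids, truth, gamma) - oracle_score(emb, truth))


if __name__ == "__main__":
    emb = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
    print(clustering_loss(emb, [0, 0, 1, 1]))  # labels agree with the geometry
    print(clustering_loss(emb, [0, 1, 0, 1]))  # labels cut across the geometry
```

With two well-separated pairs, ground-truth labels that match the geometry give a loss of (numerically) zero, while labels that pair distant points give a large loss that training would push down. The greedy step and the pairwise exchange are heuristics, so the sketch returns a local maximum of the augmented score rather than a guaranteed global one.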
EP17817337.3A 2016-11-15 2017-11-15 Training neural networks using a clustering loss Active EP3542319B1 (en)

Applications Claiming Priority (2)

Application Number Priority Date Filing Date Title
US201662422555P 2016-11-15 2016-11-15
PCT/US2017/061853 WO2018093935A1 (en) 2016-11-15 2017-11-15 Training neural networks using a clustering loss

Publications (2)

Publication Number Publication Date
EP3542319A1 (en) 2019-09-25
EP3542319B1 (en) 2023-07-26

Family

ID=60702968

Family Applications (1)

Application Number Title Priority Date Filing Date
EP17817337.3A Active EP3542319B1 (en) 2016-11-15 2017-11-15 Training neural networks using a clustering loss

Country Status (4)

Country Link
US (1) US11636314B2 (en)
EP (1) EP3542319B1 (en)
CN (1) CN109983480B (en)
WO (1) WO2018093935A1 (en)

Families Citing this family (13)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US11995537B1 (en) * 2018-03-14 2024-05-28 Perceive Corporation Training network with batches of input instances
US11586902B1 (en) 2018-03-14 2023-02-21 Perceive Corporation Training network to minimize worst case surprise
US11497478B2 (en) * 2018-05-21 2022-11-15 Siemens Medical Solutions Usa, Inc. Tuned medical ultrasound imaging
US20190370651A1 (en) * 2018-06-01 2019-12-05 Nec Laboratories America, Inc. Deep Co-Clustering
EP3770832A1 (en) * 2019-07-23 2021-01-27 Nokia Technologies Oy Workload data
US11657268B1 (en) * 2019-09-27 2023-05-23 Waymo Llc Training neural networks to assign scores
US11651209B1 (en) 2019-10-02 2023-05-16 Google Llc Accelerated embedding layer computations
US10783257B1 (en) * 2019-12-20 2020-09-22 Capital One Services, Llc Use of word embeddings to locate sensitive text in computer programming scripts
WO2021150016A1 (en) 2020-01-20 2021-07-29 Samsung Electronics Co., Ltd. Methods and systems for performing tasks on media using attribute specific joint learning
CN111429887B (en) * 2020-04-20 2023-05-30 合肥讯飞数码科技有限公司 Speech keyword recognition method, device and equipment based on end-to-end
US11854052B2 (en) * 2021-08-09 2023-12-26 Ebay Inc. Forward contracts in e-commerce
CN114897069B (en) * 2022-05-09 2023-04-07 大庆立能电力机械设备有限公司 Intelligent control energy-saving protection device for oil pumping unit
CN116503675B (en) * 2023-06-27 2023-08-29 南京理工大学 Multi-category target identification method and system based on strong clustering loss function

Family Cites Families (20)

Publication number Priority date Publication date Assignee Title
US5065339A (en) 1990-05-22 1991-11-12 International Business Machines Corporation Orthogonal row-column neural processor
US6119112A (en) * 1997-11-19 2000-09-12 International Business Machines Corporation Optimum cessation of training in neural networks
US20040024750A1 (en) 2002-07-31 2004-02-05 Ulyanov Sergei V. Intelligent mechatronic control suspension system based on quantum soft computing
US8364639B1 (en) * 2007-10-11 2013-01-29 Parallels IP Holdings GmbH Method and system for creation, analysis and navigation of virtual snapshots
US8086549B2 (en) * 2007-11-09 2011-12-27 Microsoft Corporation Multi-label active learning
US8204838B2 (en) * 2009-04-10 2012-06-19 Microsoft Corporation Scalable clustering
US9082082B2 (en) * 2011-12-06 2015-07-14 The Trustees Of Columbia University In The City Of New York Network information methods devices and systems
PT2639749T (en) 2012-03-15 2017-01-18 Cortical Io Gmbh Methods, apparatus and products for semantic processing of text
US10423889B2 (en) 2013-01-08 2019-09-24 Purepredictive, Inc. Native machine learning integration for a data management product
CN103530689B (en) * 2013-10-31 2016-01-20 中国科学院自动化研究所 Clustering method based on deep learning
JP6588449B2 (en) 2014-01-31 2019-10-09 Google LLC Generating a vector representation of a document
CN103914735B (en) * 2014-04-17 2017-03-29 北京泰乐德信息技术有限公司 Fault recognition method and system based on neural network self-learning
US10289962B2 (en) * 2014-06-06 2019-05-14 Google Llc Training distilled machine learning models
CN104299035A (en) * 2014-09-29 2015-01-21 国家电网公司 Method for diagnosing fault of transformer on basis of clustering algorithm and neural network
EP3204896A1 (en) * 2014-10-07 2017-08-16 Google, Inc. Training neural networks on partitioned training data
US10387773B2 (en) * 2014-10-27 2019-08-20 Ebay Inc. Hierarchical deep convolutional neural network for image classification
EP3234871B1 (en) * 2014-12-17 2020-11-25 Google LLC Generating numeric embeddings of images
US10628733B2 (en) 2015-04-06 2020-04-21 Deepmind Technologies Limited Selecting reinforcement learning actions using goals and observations
CN104933438A (en) * 2015-06-01 2015-09-23 武艳娇 Image clustering method based on self-coding neural network
CN105701571A (en) * 2016-01-13 2016-06-22 南京邮电大学 Short-term traffic flow prediction method based on neural network combination model

Also Published As

Publication number Publication date
WO2018093935A1 (en) 2018-05-24
CN109983480B (en) 2023-05-26
US20200065656A1 (en) 2020-02-27
EP3542319B1 (en) 2023-07-26
CN109983480A (en) 2019-07-05
US11636314B2 (en) 2023-04-25

Similar Documents

Publication Publication Date Title
US11636314B2 (en) Training neural networks using a clustering loss
US20210334624A1 (en) Neural architecture search using a performance prediction neural network
US11669744B2 (en) Regularized neural network architecture search
CN109564575B (en) Classifying images using machine learning models
US11681924B2 (en) Training neural networks using a variational information bottleneck
US11544536B2 (en) Hybrid neural architecture search
EP3234871B1 (en) Generating numeric embeddings of images
US11941527B2 (en) Population based training of neural networks
US20200057936A1 (en) Semi-supervised training of neural networks
US20240127058A1 (en) Training neural networks using priority queues
US20200005152A1 (en) Training neural networks using posterior sharpening
US20220230065A1 (en) Semi-supervised training of machine learning models using label guessing
US20220383036A1 (en) Clustering data using neural networks based on normalized cuts
US11907825B2 (en) Training neural networks using distributed batch normalization
US20220253694A1 (en) Training neural networks with reinitialization
US20220335274A1 (en) Multi-stage computationally efficient neural network inference
US20220180147A1 (en) Energy-based associative memory neural networks

Legal Events

Date Code Title Description
STAA Information on the status of an ep patent application or granted ep patent

Free format text: STATUS: UNKNOWN

STAA Information on the status of an ep patent application or granted ep patent

Free format text: STATUS: THE INTERNATIONAL PUBLICATION HAS BEEN MADE

PUAI Public reference made under article 153(3) epc to a published international application that has entered the european phase

Free format text: ORIGINAL CODE: 0009012

STAA Information on the status of an ep patent application or granted ep patent

Free format text: STATUS: REQUEST FOR EXAMINATION WAS MADE

17P Request for examination filed

Effective date: 20190513

AK Designated contracting states

Kind code of ref document: A1

Designated state(s): AL AT BE BG CH CY CZ DE DK EE ES FI FR GB GR HR HU IE IS IT LI LT LU LV MC MK MT NL NO PL PT RO RS SE SI SK SM TR

AX Request for extension of the european patent

Extension state: BA ME

DAV Request for validation of the european patent (deleted)
DAX Request for extension of the european patent (deleted)
STAA Information on the status of an ep patent application or granted ep patent

Free format text: STATUS: EXAMINATION IS IN PROGRESS

17Q First examination report despatched

Effective date: 20210413

STAA Information on the status of an ep patent application or granted ep patent

Free format text: STATUS: EXAMINATION IS IN PROGRESS

REG Reference to a national code

Ref country code: DE

Ref legal event code: R079

Ref document number: 602017071877

Country of ref document: DE

Free format text: PREVIOUS MAIN CLASS: G06N0003080000

Ipc: G06N0003046400

Ref country code: DE

Free format text: PREVIOUS MAIN CLASS: G06N0003080000

GRAP Despatch of communication of intention to grant a patent

Free format text: ORIGINAL CODE: EPIDOSNIGR1

STAA Information on the status of an ep patent application or granted ep patent

Free format text: STATUS: GRANT OF PATENT IS INTENDED

RIC1 Information provided on ipc code assigned before grant

Ipc: G06N 3/0895 20230101ALI20230223BHEP

Ipc: G06N 3/084 20230101ALI20230223BHEP

Ipc: G06N 3/0464 20230101AFI20230223BHEP

INTG Intention to grant announced

Effective date: 20230316

GRAS Grant fee paid

Free format text: ORIGINAL CODE: EPIDOSNIGR3

GRAA (expected) grant

Free format text: ORIGINAL CODE: 0009210

STAA Information on the status of an ep patent application or granted ep patent

Free format text: STATUS: THE PATENT HAS BEEN GRANTED

P01 Opt-out of the competence of the unified patent court (upc) registered

Effective date: 20230527

AK Designated contracting states

Kind code of ref document: B1

Designated state(s): AL AT BE BG CH CY CZ DE DK EE ES FI FR GB GR HR HU IE IS IT LI LT LU LV MC MK MT NL NO PL PT RO RS SE SI SK SM TR

REG Reference to a national code

Ref country code: CH

Ref legal event code: EP

REG Reference to a national code

Ref country code: DE

Ref legal event code: R096

Ref document number: 602017071877

Country of ref document: DE

REG Reference to a national code

Ref country code: IE

Ref legal event code: FG4D

REG Reference to a national code

Ref country code: LT

Ref legal event code: MG9D

REG Reference to a national code

Ref country code: NL

Ref legal event code: MP

Effective date: 20230726

REG Reference to a national code

Ref country code: AT

Ref legal event code: MK05

Ref document number: 1592794

Country of ref document: AT

Kind code of ref document: T

Effective date: 20230726

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: NL

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: GR

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20231027

PGFP Annual fee paid to national office [announced via postgrant information from national office to epo]

Ref country code: GB

Payment date: 20231127

Year of fee payment: 7

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: IS

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20231126

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: SE

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: RS

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: PT

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20231127

Ref country code: NO

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20231026

Ref country code: LV

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: LT

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: IS

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20231126

Ref country code: HR

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: GR

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20231027

Ref country code: FI

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: AT

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

PGFP Annual fee paid to national office [announced via postgrant information from national office to epo]

Ref country code: DE

Payment date: 20231129

Year of fee payment: 7

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: PL

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: ES

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

REG Reference to a national code

Ref country code: DE

Ref legal event code: R097

Ref document number: 602017071877

Country of ref document: DE

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: SM

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: RO

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: ES

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: EE

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: DK

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: CZ

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

Ref country code: SK

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: IT

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

PLBE No opposition filed within time limit

Free format text: ORIGINAL CODE: 0009261

STAA Information on the status of an ep patent application or granted ep patent

Free format text: STATUS: NO OPPOSITION FILED WITHIN TIME LIMIT

REG Reference to a national code

Ref country code: CH

Ref legal event code: PL

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: MC

Free format text: LAPSE BECAUSE OF FAILURE TO SUBMIT A TRANSLATION OF THE DESCRIPTION OR TO PAY THE FEE WITHIN THE PRESCRIBED TIME-LIMIT

Effective date: 20230726

26N No opposition filed

Effective date: 20240429

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: LU

Free format text: LAPSE BECAUSE OF NON-PAYMENT OF DUE FEES

Effective date: 20231115

PG25 Lapsed in a contracting state [announced via postgrant information from national office to epo]

Ref country code: CH

Free format text: LAPSE BECAUSE OF NON-PAYMENT OF DUE FEES

Effective date: 20231130